Skip to content

Commit

Permalink
Smaller correctness changes:
Browse files Browse the repository at this point in the history
- Ensure fake_file_progress.py works as expected.
- Optimize _event_instrumentation.py `consume` method
- Allow passing of message version to `reset_changes` to prevent losing important information
  • Loading branch information
cjavad committed Jan 29, 2025
1 parent 542e643 commit 1edb26d
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 70 deletions.
16 changes: 7 additions & 9 deletions simplyprint_ws_client/core/_event_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,15 @@ def consume(state: PrinterState) -> Tuple[List[ClientMsg], int]:
"""Consume state from mappings"""
changes = state.model_recursive_changeset

msg_kinds = {
_client_msg_map[k]: v for k, v in changes.items() if k in _client_msg_map}
# Sort by v
msg_kinds = sorted(
((_client_msg_map[k], v) for k, v in changes.items() if k in _client_msg_map), key=lambda x: x[1])

# Sort by v and make it a list
msg_kinds = [k for k in sorted(msg_kinds, key=msg_kinds.get)]
is_pending = state.config.is_pending()

msgs = []

is_pending = state.config.is_pending()

for msg_kind in msg_kinds:
for msg_kind, v in msg_kinds:
# Skip over messages that are not allowed to be sent when pending.
if is_pending and not msg_kind.msg_type().when_pending():
continue
Expand All @@ -230,6 +228,6 @@ def consume(state: PrinterState) -> Tuple[List[ClientMsg], int]:
continue

msgs.append(msg)
msg.reset_changes(state)
msg.reset_changes(state, v=v)

return msgs, -1
return msgs, -1 # max(v for _, v in msg_kinds)
67 changes: 35 additions & 32 deletions simplyprint_ws_client/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,49 +118,52 @@ def _should_schedule_client(self, client: Client, when: datetime):

async def _schedule_client(self, client: Client):
"""Schedule single client."""
was_allocated = self.manager.is_allocated(client)
try:
was_allocated = self.manager.is_allocated(client)

if not client.active:
if not was_allocated:
return
if not client.active:
if not was_allocated:
return

# Remove the connection from the multi printer.
if not await client.ensure_removed(self.settings.mode):
return
# Remove the connection from the multi printer.
if not await client.ensure_removed(self.settings.mode):
return

# Then we can deallocate the client from the connection.
await self.manager.deallocate(client)
await client.halt()
return
# Then we can deallocate the client from the connection.
await self.manager.deallocate(client)
await client.halt()
return

if not was_allocated:
await self.manager.allocate(client)
await client.init()
if not was_allocated:
await self.manager.allocate(client)
await client.init()

# Progress inner client state until we reach CONNECTED state.
# e.i. in multi printer mode until we receive the connected message.
if not await client.ensure_added(self.settings.mode, self.settings.allow_setup):
return
# Progress inner client state until we reach CONNECTED state.
# e.i. in multi printer mode until we receive the connected message.
if not await client.ensure_added(self.settings.mode, self.settings.allow_setup):
return

# Tick client.
last_ticked = self._last_ticked.get(client.unique_id, datetime.min)
now = datetime.now()
delta_tick = now - last_ticked
# Tick client.
last_ticked = self._last_ticked.get(client.unique_id, datetime.min)
now = datetime.now()
delta_tick = now - last_ticked

if delta_tick >= self._tick_rate_delta:
self._last_ticked[client.unique_id] = now
if delta_tick >= self._tick_rate_delta:
self._last_ticked[client.unique_id] = now

# TODO: Manage timeouts.
async with asyncio.timeout(5):
await client.tick(delta_tick)
# TODO: Manage timeouts.
async with asyncio.timeout(5):
await client.tick(delta_tick)

if not client.has_changes:
return
if not client.has_changes:
return

msgs, v = client.consume()
msgs, v = client.consume()

for msg in msgs:
await client.send(msg)
for msg in msgs:
await client.send(msg)
except Exception as e:
client.logger.error("Error while scheduling client", exc_info=e)

def _process_clients(self):
"""Schedule all clients for processing."""
Expand Down
19 changes: 16 additions & 3 deletions simplyprint_ws_client/core/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
]

import time
from typing import Optional, Literal
from typing import Union, List, Set, \
ClassVar
from typing import Optional, Literal, no_type_check, Union, List, Set, ClassVar

from pydantic import Field, field_validator

Expand Down Expand Up @@ -100,6 +98,14 @@ class FileProgressState(StateModel):
percent: float = 0.0
message: Optional[str] = None

@no_type_check
def __setattr__(self, key, value):
super().__setattr__(key, value)

# Reset the progress when the state changes away from downloading.
if key == 'state' and value != FileProgressStateEnum.DOWNLOADING:
self.percent = 0.0


class CpuInfoState(StateModel):
usage: Optional[float] = None
Expand Down Expand Up @@ -174,6 +180,7 @@ class JobInfoState(StateModel, validate_assignment=True):
def convert_to_always_true(cls, value):
return ExclusiveBool(value=value)

@no_type_check
def __setattr__(self, key, value):
"""Only one of the 4 fields can be True at a time."""
if not key in self.MUTUALLY_EXCLUSIVE_FIELDS:
Expand Down Expand Up @@ -264,6 +271,9 @@ def set_nozzle_count(self, count: int) -> None:
if count < 1:
raise ValueError("Nozzle count must be at least 1")

if len(self.tool_temperatures) == count:
return

if count > len(self.tool_temperatures):
for _ in range(count - len(self.tool_temperatures)):
self.tool_temperatures.append(model := TemperatureState())
Expand All @@ -278,6 +288,9 @@ def set_extruder_count(self, count: int) -> None:
if self.active_tool is not None and self.active_tool >= count:
self.active_tool = None

if len(self.material_data) == count:
return

if count > len(self.material_data):
for _ in range(count - len(self.material_data)):
self.material_data.append(model := MaterialState())
Expand Down
15 changes: 10 additions & 5 deletions simplyprint_ws_client/core/state/state_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

from .exclusive_bool import ExclusiveBool

if TYPE_CHECKING:
from ..client import Client

TStateModel = TypeVar("TStateModel", bound="StateModel")


Expand Down Expand Up @@ -97,16 +94,24 @@ def provide_context(self, ctx: Union['StateModel', weakref.ref]) -> None:
if isinstance(value := getattr(self, field), StateModel):
value.provide_context(self)

def model_reset_changed(self, *keys: str) -> None:
def model_reset_changed(self, *keys: str, v: Optional[int] = None) -> None:
"""
Reset the changed state, this will clear model_self_changed_fields.
"""

if not keys:
if not keys and v is None:
object.__setattr__(self, "model_self_changed_fields", {})
return

if not keys and v is not None:
object.__setattr__(self, "model_self_changed_fields",
{k: v2 for k, v2 in self.model_self_changed_fields.items() if v2 > v})
return

for key in keys:
if v is not None and self.model_self_changed_fields.get(key, -1) > v:
continue

self.model_self_changed_fields.pop(key, None)

@property
Expand Down
36 changes: 18 additions & 18 deletions simplyprint_ws_client/core/ws_protocol/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
"""Construct a dict with data based on the current state"""
raise NotImplemented()

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
"""
Custom reset logic for the message, typically used to
fully reset some state, but optionally comparing and only
Expand Down Expand Up @@ -592,7 +592,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
for key in state.info.model_fields:
yield key, getattr(state.info, key)

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.info.model_reset_changed()


Expand All @@ -601,7 +601,7 @@ class WebcamStatusMsg(ClientMsg[Literal[ClientMsgType.WEBCAM_STATUS]]):
def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
yield "connected", state.webcam_info.connected

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.webcam_info.model_reset_changed()


Expand All @@ -611,7 +611,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
for key in state.webcam_settings.model_changed_fields:
yield key, getattr(state.webcam_settings, key)

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.webcam_settings.model_reset_changed()


Expand All @@ -634,7 +634,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
{(f"firmware_{key}" if key != "name" else "firmware"): value for key in
state.firmware.model_fields if (value := getattr(state.firmware, key)) is not None})

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.firmware.model_reset_changed()


Expand All @@ -644,7 +644,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
for key in state.firmware_warning.model_changed_fields:
yield key, getattr(state.firmware_warning, key)

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.firmware_warning.model_reset_changed()


Expand All @@ -653,7 +653,7 @@ class ToolMsg(ClientMsg[Literal[ClientMsgType.TOOL]]):
def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
yield "new", state.active_tool

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.model_reset_changed("active_tool")


Expand All @@ -669,7 +669,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:

yield f"tool{i}", tool.to_list()

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.bed_temperature.model_reset_changed()

for tool in state.tool_temperatures:
Expand All @@ -691,7 +691,7 @@ class AmbientTemperatureMsg(ClientMsg[Literal[ClientMsgType.AMBIENT]]):
def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
yield "new", state.ambient_temperature.ambient

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.ambient_temperature.model_reset_changed()


Expand All @@ -707,7 +707,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:

yield "new", state.status

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.model_reset_changed("status")


Expand All @@ -728,7 +728,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:

yield key, value

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.job_info.model_reset_changed()

def dispatch_mode(self, state: PrinterState) -> DispatchMode:
Expand Down Expand Up @@ -770,7 +770,7 @@ class LatencyMsg(ClientMsg[Literal[ClientMsgType.LATENCY]]):
def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
yield "ms", state.latency.get_latency()

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.latency.model_reset_changed()


Expand All @@ -790,16 +790,16 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
if state.file_progress.state in (FileProgressStateEnum.DOWNLOADING, FileProgressStateEnum.STARTED):
yield "percent", state.file_progress.percent

def reset_changes(self, state: PrinterState) -> None:
state.file_progress.model_reset_changed()
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.file_progress.model_reset_changed(v=v)


class FilamentSensorMsg(ClientMsg[Literal[ClientMsgType.FILAMENT_SENSOR]]):
@classmethod
def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
yield "state", state.filament_sensor.state

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.filament_sensor.model_reset_changed()


Expand All @@ -808,7 +808,7 @@ class PowerControllerMsg(ClientMsg[Literal[ClientMsgType.PSU]]):
def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
yield "on", state.psu_info

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.psu_info.model_reset_changed()


Expand All @@ -818,7 +818,7 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
for key in state.cpu_info.model_changed_fields:
yield key, getattr(state.cpu_info, key)

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
state.cpu_info.model_reset_changed()

def dispatch_mode(self, state: PrinterState) -> DispatchMode:
Expand All @@ -842,6 +842,6 @@ def build(cls, state: PrinterState) -> TClientMsgDataGenerator:
if any(m.model_has_changed for m in state.material_data):
yield "materials", [m.model_dump(mode='json') if m.type is not None else None for m in state.material_data]

def reset_changes(self, state: PrinterState) -> None:
def reset_changes(self, state: PrinterState, v: Optional[int] = None) -> None:
for material in state.material_data:
material.model_reset_changed()
8 changes: 5 additions & 3 deletions simplyprint_ws_client/shared/sp/fake_file_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

__all__ = ['FakeFileProgress']

import asyncio
from typing import Optional

from simplyprint_ws_client.core.client import Client
from simplyprint_ws_client.core.state import FileProgressStateEnum
from simplyprint_ws_client.shared.asyncio.continuous_task import ContinuousTask
from simplyprint_ws_client.shared.asyncio.event_loop_provider import EventLoopProvider
from simplyprint_ws_client.shared.utils.backoff import Backoff, ConstantBackoff
from simplyprint_ws_client.shared.utils.stoppable import AsyncStoppable

Expand All @@ -16,11 +18,12 @@ class FakeFileProgress(AsyncStoppable):
backoff: Backoff
fake_progress_task: ContinuousTask

def __init__(self, client: Client, backoff: Optional[Backoff] = None, **kwargs) -> None:
def __init__(self, client: Client, backoff: Optional[Backoff] = None,
event_loop_provider: Optional[EventLoopProvider[asyncio.AbstractEventLoop]] = None, **kwargs) -> None:
super().__init__(**kwargs)
self.client = client
self.backoff = backoff or ConstantBackoff(5)
self.fake_progress_task = ContinuousTask(self.fake_progress)
self.fake_progress_task = ContinuousTask(self.fake_progress, provider=event_loop_provider)

@property
def is_downloading(self):
Expand All @@ -42,6 +45,5 @@ def tick(self):

async def fake_progress(self):
while not self.is_stopped() and self.is_downloading:
print("FAKE PROGRESS")
self.client.printer.file_progress.model_set_changed("state", "progress")
await self.wait(self.backoff.delay())

0 comments on commit 1edb26d

Please sign in to comment.