Skip to content

Commit

Permalink
Move the helpers to a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 26, 2024
1 parent 516d006 commit f33acd1
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 45 deletions.
44 changes: 44 additions & 0 deletions src/dispatch/any.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

import pickle
from typing import Any

import google.protobuf.any_pb2
import google.protobuf.message
import google.protobuf.wrappers_pb2
from google.protobuf import descriptor_pool, message_factory

from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb


def marshal_any(value: Any) -> google.protobuf.any_pb2.Any:
any = google.protobuf.any_pb2.Any()
if isinstance(value, google.protobuf.message.Message):
any.Pack(value)
else:
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
return any


def unmarshal_any(any: google.protobuf.any_pb2.Any) -> Any:
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
try:
# Assume it's the legacy container for pickled values.
return pickle.loads(b.value)
except Exception as e:
# Otherwise, return the literal bytes.
return b.value

pool = descriptor_pool.Default()
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
proto = message_factory.GetMessageClass(msg_descriptor)()
any.Unpack(proto)
return proto
50 changes: 9 additions & 41 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tblib # type: ignore[import-untyped]
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.any import marshal_any, unmarshal_any
from dispatch.error import IncompatibleStateError, InvalidArgumentError
from dispatch.id import DispatchID
from dispatch.sdk.python.v1 import pickled_pb2 as pickled_pb
Expand Down Expand Up @@ -78,11 +79,11 @@ def __init__(self, req: function_pb.RunRequest):

self._has_input = req.HasField("input")
if self._has_input:
self._input = _any_unpickle(req.input)
self._input = unmarshal_any(req.input)
else:
if req.poll_result.coroutine_state:
raise IncompatibleStateError # coroutine_state is deprecated
self._coroutine_state = _any_unpickle(req.poll_result.typed_coroutine_state)
self._coroutine_state = unmarshal_any(req.poll_result.typed_coroutine_state)
self._call_results = [
CallResult._from_proto(r) for r in req.poll_result.results
]
Expand Down Expand Up @@ -141,7 +142,7 @@ def from_input_arguments(cls, function: str, *args, **kwargs):
return Input(
req=function_pb.RunRequest(
function=function,
input=_any_pickle(input),
input=marshal_any(input),
)
)

Expand All @@ -157,7 +158,7 @@ def from_poll_results(
req=function_pb.RunRequest(
function=function,
poll_result=poll_pb.PollResult(
typed_coroutine_state=_any_pickle(coroutine_state),
typed_coroutine_state=marshal_any(coroutine_state),
results=[result._as_proto() for result in call_results],
error=error._as_proto() if error else None,
),
Expand Down Expand Up @@ -241,7 +242,7 @@ def poll(
else None
)
poll = poll_pb.Poll(
typed_coroutine_state=_any_pickle(coroutine_state),
typed_coroutine_state=marshal_any(coroutine_state),
min_results=min_results,
max_results=max_results,
max_wait=max_wait,
Expand Down Expand Up @@ -279,7 +280,7 @@ class Call:
correlation_id: Optional[int] = None

def _as_proto(self) -> call_pb.Call:
input_bytes = _any_pickle(self.input)
input_bytes = marshal_any(self.input)
return call_pb.Call(
correlation_id=self.correlation_id,
endpoint=self.endpoint,
Expand All @@ -301,7 +302,7 @@ def _as_proto(self) -> call_pb.CallResult:
output_any = None
error_proto = None
if self.output is not None:
output_any = _any_pickle(self.output)
output_any = marshal_any(self.output)
if self.error is not None:
error_proto = self.error._as_proto()

Expand All @@ -317,7 +318,7 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult:
output = None
error = None
if proto.HasField("output"):
output = _any_unpickle(proto.output)
output = unmarshal_any(proto.output)
if proto.HasField("error"):
error = Error._from_proto(proto.error)

Expand Down Expand Up @@ -438,36 +439,3 @@ def _as_proto(self) -> error_pb.Error:
return error_pb.Error(
type=self.type, message=self.message, value=value, traceback=self.traceback
)


def _any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
any = google.protobuf.any_pb2.Any()
if isinstance(value, google.protobuf.message.Message):
any.Pack(value)
else:
p = pickled_pb.Pickled(pickled_value=pickle.dumps(value))
any.Pack(p, type_url_prefix="buf.build/stealthrocket/dispatch-proto/")
return any


def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
if any.Is(pickled_pb.Pickled.DESCRIPTOR):
p = pickled_pb.Pickled()
any.Unpack(p)
return pickle.loads(p.pickled_value)

elif any.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
b = google.protobuf.wrappers_pb2.BytesValue()
any.Unpack(b)
try:
# Assume it's the legacy container for pickled values.
return pickle.loads(b.value)
except Exception as e:
# Otherwise, return the literal bytes.
return b.value

pool = descriptor_pool.Default()
msg_descriptor = pool.FindMessageTypeByName(any.TypeName())
proto = message_factory.GetMessageClass(msg_descriptor)()
any.Unpack(proto)
return proto
6 changes: 3 additions & 3 deletions tests/dispatch/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import pytest

from dispatch.any import unmarshal_any
from dispatch.coroutine import AnyException, any, call, gather, race
from dispatch.experimental.durable import durable
from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.scheduler import (
AllFuture,
AnyFuture,
Expand Down Expand Up @@ -464,7 +464,7 @@ async def resume(
poll = assert_poll(prev_output)
input = Input.from_poll_results(
main.__qualname__,
any_unpickle(poll.typed_coroutine_state),
unmarshal_any(poll.typed_coroutine_state),
call_results,
Error.from_exception(poll_error) if poll_error else None,
)
Expand All @@ -489,7 +489,7 @@ def assert_exit_result_value(output: Output, expect: Any):
result = assert_exit_result(output)
assert result.HasField("output")
assert not result.HasField("error")
assert expect == any_unpickle(result.output)
assert expect == unmarshal_any(result.output)


def assert_exit_result_error(
Expand Down
1 change: 0 additions & 1 deletion tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from dispatch.experimental.durable.registry import clear_functions
from dispatch.fastapi import Dispatch
from dispatch.function import Arguments, Client, Error, Input, Output, Registry
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
Expand Down

0 comments on commit f33acd1

Please sign in to comment.