Skip to content

Commit

Permalink
fix unset values vs empty values
Browse files Browse the repository at this point in the history
  • Loading branch information
deigen committed Mar 3, 2025
1 parent 32420ce commit 10e114d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
2 changes: 1 addition & 1 deletion clarifai/runners/models/model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _convert_output_to_proto(self, output: Any, variables_signature,
output = {f'return.{i}': item for i, item in enumerate(output)}
if not isinstance(output, dict): # TODO Output type, not just dict
output = {'return': output}
serialize(output, variables_signature, proto.data)
serialize(output, variables_signature, proto.data, is_output=True)
return proto

@classmethod
Expand Down
33 changes: 28 additions & 5 deletions clarifai/runners/utils/method_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def signatures_from_yaml(yaml_str):
return signatures_from_json(json.dumps(d))


def serialize(kwargs, signatures, proto=None):
def serialize(kwargs, signatures, proto=None, is_output=False):
'''
Serialize the given kwargs into the proto using the given signatures.
'''
Expand All @@ -168,19 +168,30 @@ def serialize(kwargs, signatures, proto=None):
raise TypeError(f'Missing required argument: {sig.name}')
continue # skip missing fields, they can be set to default on the server
data = kwargs[sig.name]
data_proto, field = _get_named_part(proto, sig.data_field, add_parts=True)
force_named_part = (_is_empty_proto_data(data) and not is_output and not sig.required)
data_proto, field = _get_data_part(
proto, sig, is_output=is_output, serializing=True, force_named_part=force_named_part)
serializer = get_serializer(sig.data_type)
serializer.serialize(data_proto, field, data)
return proto


def _is_empty_proto_data(data):
return isinstance(data, (str, bytes, int, float, bool, np.number)) and not data


def deserialize(proto, signatures, is_output=False):
'''
Deserialize the given proto into kwargs using the given signatures.
'''
kwargs = {}
for sig in signatures:
data_proto, field = _get_named_part(proto, sig.data_field, add_parts=False)
data_proto, field = _get_data_part(proto, sig, is_output=is_output, serializing=False)
if data_proto is None:
# not set in proto, check if required or skip if optional arg
if not is_output and sig.required:
raise ValueError(f'Missing required field: {sig.name}')
continue
serializer = get_serializer(sig.data_type)
data = serializer.deserialize(data_proto, field)
kwargs[sig.name] = data
Expand All @@ -203,7 +214,13 @@ def get_serializer(data_type: str) -> Serializer:
raise ValueError(f'Unsupported type: "{data_type}"')


def _get_named_part(proto, field, add_parts):
def _get_data_part(proto, sig, is_output, serializing, force_named_part=False):
field = sig.data_field

# check if we need to force a named part, to distinguish between empty and unset values
if force_named_part and not field.startswith('parts['):
field = f'parts[{sig.name}].{field}'

# gets the named part from the proto, according to the field path
# note we only support one level of named parts
parts = field.replace(' ', '').split('.')
Expand All @@ -212,6 +229,12 @@ def _get_named_part(proto, field, add_parts):
raise ValueError('Invalid field: %s' % field)

if len(parts) == 1:
# also need to check if there is an explicitly named part, e.g. for empty values
part = next((part for part in proto.parts if part.id == sig.name), None)
if part:
return part.data, field
if not serializing and not is_output and not getattr(proto, field):
return None, field
return proto, field

# list
Expand All @@ -228,7 +251,7 @@ def _get_named_part(proto, field, add_parts):
assert len(parts) in (2, 3) # parts[name].field, parts[name].parts[].field
part = next((part for part in proto.parts if part.id == name), None)
if part is None:
if not add_parts:
if not serializing:
raise ValueError('Missing part: %s' % name)
part = proto.parts.add()
part.id = name
Expand Down
6 changes: 5 additions & 1 deletion tests/runners/test_model_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def f(self, input: str) -> Output(x=int, y=str):
self.assertEqual(result.x, 3)
self.assertEqual(result.y, 'abc result')

def test_kwarg_defaults(self):
def test_kwarg_defaults_one_arg(self):

class MyModel(ModelClass):

Expand Down Expand Up @@ -613,6 +613,10 @@ def f(self, x: int = 5) -> int:
client = _get_servicer_client(MyModel())
result = client.f()
self.assertEqual(result, 6)
result = client.f(0)
self.assertEqual(result, 1)
result = client.f(-1)
self.assertEqual(result, 0)
result = client.f(10)
self.assertEqual(result, 11)
with self.assertRaises(TypeError):
Expand Down

0 comments on commit 10e114d

Please sign in to comment.