Skip to content

Commit

Permalink
Remove comparisons with string representation of structures in `tenso…
Browse files Browse the repository at this point in the history
…rflow_utils_test`.

Comparing with string representations makes the tests fragile and difficult to maintain; the string representation of `collections.OrderedDict` changes in Python 3.12 and this will make the tests robust to Python updates.

PiperOrigin-RevId: 723228315
  • Loading branch information
ZacharyGarrett authored and copybara-github committed Feb 4, 2025
1 parent a61b3cb commit a96a2d8
Showing 1 changed file with 49 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def test_compute_map_from_bindings_with_sequence(self):
sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='bar')
)
result = tensorflow_utils._compute_map_from_bindings(source, target)
self.assertEqual(str(result), "OrderedDict([('foo', 'bar')])")
self.assertEqual(result, collections.OrderedDict(foo='bar'))

def test_extract_tensor_names_from_binding_with_tuple_of_tensors(self):
with tf.Graph().as_default() as graph:
Expand Down Expand Up @@ -956,24 +956,23 @@ def test_make_empty_list_structure_for_element_type_spec_w_tuple_dict(self):
result = tensorflow_utils._make_empty_list_structure_for_element_type_spec(
type_spec
)
self.assertEqual(str(result), "([], OrderedDict([('a', []), ('b', [])]))")
self.assertEqual(result, ([], collections.OrderedDict(a=[], b=[])))

def test__append_to_list_structure_for_element_type_spec_w_tuple_dict(self):
def test_append_to_list_structure_for_element_type_spec_w_tuple_dict(self):
nested = tuple([[], collections.OrderedDict([('a', []), ('b', [])])])
type_spec = [tf.int32, [('a', tf.bool), ('b', tf.float32)]]
for value in [[10, {'a': True, 'b': 30}], (40, [False, 60])]:
tensorflow_utils._append_to_list_structure_for_element_type_spec(
nested, value, type_spec
)
self.assertEqual(
str(nested),
self.assertAllEqual(
nested,
(
'([<tf.Tensor: shape=(), dtype=int32, numpy=10>, <tf.Tensor:'
" shape=(), dtype=int32, numpy=40>], OrderedDict([('a',"
' [<tf.Tensor: shape=(), dtype=bool, numpy=True>, <tf.Tensor:'
" shape=(), dtype=bool, numpy=False>]), ('b', [<tf.Tensor:"
' shape=(), dtype=float32, numpy=30.0>, <tf.Tensor: shape=(),'
' dtype=float32, numpy=60.0>])]))'
[tf.constant(10, tf.int32), tf.constant(40, tf.int32)],
collections.OrderedDict(
a=[tf.constant(True, tf.bool), tf.constant(False, tf.bool)],
b=[tf.constant(30, tf.float32), tf.constant(60, tf.float32)],
),
),
)

Expand Down Expand Up @@ -1032,7 +1031,7 @@ def test__replace_empty_leaf_lists_with_numpy_arrays(self):
str(result).replace(' ', ''), str(expected_structure).replace(' ', '')
)

def _test_list_structure(self, type_spec, elements, expected_output_str):
def _test_list_structure(self, type_spec, elements, expected_output):
result = tensorflow_utils._make_empty_list_structure_for_element_type_spec(
type_spec
)
Expand All @@ -1043,85 +1042,85 @@ def _test_list_structure(self, type_spec, elements, expected_output_str):
result = tensorflow_utils._replace_empty_leaf_lists_with_numpy_arrays(
result, type_spec
)
self.assertEqual(
str(result).replace(' ', ''), expected_output_str.replace(' ', '')
)
# Use assertAllClose instead of allEqual for the behavior that empty
# arrays are equal.
self.assertAllClose(result, expected_output, rtol=0.0, atol=0.0)

def test_list_structures_from_element_type_spec_with_none_value(self):
self._test_list_structure(
[tf.int32, [('a', tf.bool), ('b', tf.float32)]],
[None],
str(
tuple([
np.array([], dtype=np.int32),
collections.OrderedDict([
('a', np.array([], dtype=bool)),
('b', np.array([], dtype=np.float32)),
]),
])
(
np.array([], dtype=np.int32),
collections.OrderedDict(
a=np.array([], dtype=bool),
b=np.array([], dtype=np.float32),
),
),
)

def test_list_structures_from_element_type_spec_with_int_value(self):
self._test_list_structure(
tf.int32, [1], '[<tf.Tensor:shape=(),dtype=int32,numpy=1>]'
)
self._test_list_structure(tf.int32, [1], [tf.constant(1, tf.int32)])

def test_list_structures_from_element_type_spec_with_empty_dict_value(self):
self._test_list_structure(
federated_language.StructType([]), [{}], 'OrderedDict()'
federated_language.StructType([]), [{}], collections.OrderedDict()
)

def test_list_structures_from_element_type_spec_with_dict_value(self):
self._test_list_structure(
[('a', tf.int32), ('b', tf.int32)],
[{'a': 1, 'b': 2}, {'a': 1, 'b': 2}],
(
"OrderedDict([('a',["
'<tf.Tensor:shape=(),dtype=int32,numpy=1>,'
'<tf.Tensor:shape=(),dtype=int32,numpy=1>'
"]),('b',["
'<tf.Tensor:shape=(),dtype=int32,numpy=2>,'
'<tf.Tensor:shape=(),dtype=int32,numpy=2>'
'])])'
collections.OrderedDict(
a=[
tf.constant(1, tf.int32),
tf.constant(1, tf.int32),
],
b=[
tf.constant(2, tf.int32),
tf.constant(2, tf.int32),
],
),
)

def test_list_structures_from_element_type_spec_with_no_values(self):
self._test_list_structure(tf.int32, [], '[]')
self._test_list_structure(tf.int32, [], [])

def test_list_structures_from_element_type_spec_with_int_values(self):
self._test_list_structure(
tf.int32,
[1, 2, 3],
(
'[<tf.Tensor:shape=(),dtype=int32,numpy=1>,'
'<tf.Tensor:shape=(),dtype=int32,numpy=2>,'
'<tf.Tensor:shape=(),dtype=int32,numpy=3>]'
),
[
tf.constant(1, tf.int32),
tf.constant(2, tf.int32),
tf.constant(3, tf.int32),
],
)

def test_list_structures_from_element_type_spec_with_empty_dict_values(self):
self._test_list_structure(
federated_language.StructType([]), [{}, {}, {}], 'OrderedDict()'
federated_language.StructType([]),
[{}, {}, {}],
collections.OrderedDict(),
)

def test_list_structures_from_element_type_spec_with_structures(self):
self._test_list_structure(
federated_language.StructType([('a', np.int32)]),
[structure.Struct([('a', 1)]), structure.Struct([('a', 2)])],
(
"OrderedDict([('a', ["
'<tf.Tensor:shape=(),dtype=int32,numpy=1>,'
'<tf.Tensor:shape=(),dtype=int32,numpy=2>])])'
collections.OrderedDict(
a=[
tf.constant(1, tf.int32),
tf.constant(2, tf.int32),
]
),
)

def test_list_structures_from_element_type_spec_with_empty_anon_tuples(self):
self._test_list_structure(
federated_language.StructType([]),
[structure.Struct([]), structure.Struct([])],
'OrderedDict()',
collections.OrderedDict(),
)

def test_list_structures_from_element_type_spec_w_list_of_anon_tuples(self):
Expand All @@ -1131,9 +1130,9 @@ def test_list_structures_from_element_type_spec_w_list_of_anon_tuples(self):
),
[[structure.Struct([('a', 1)])], [structure.Struct([('a', 2)])]],
(
"(OrderedDict([('a', ["
'<tf.Tensor:shape=(),dtype=int32,numpy=1>,'
'<tf.Tensor:shape=(),dtype=int32,numpy=2>])]),)'
collections.OrderedDict(
a=[tf.constant(1, tf.int32), tf.constant(2, tf.int32)],
),
),
)

Expand Down

0 comments on commit a96a2d8

Please sign in to comment.