diff --git a/RELEASE.md b/RELEASE.md index d8f2ebfad8..903c3eb596 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -308,6 +308,10 @@ and this project adheres to `federated_language.framework.computation_contains` instead. * `tff.test.assert_types_equivalent`, use `federated_language.Type.is_equivalent_to` instead. +* `tff.program.NativeFederatedContext`, use + `federated_language.program.NativeFederatedContext` instead. +* `tff.program.NativeValueReference`, use + `federated_language.program.NativeValueReference` instead. ## Release 0.88.0 diff --git a/WORKSPACE b/WORKSPACE index 2fc4560182..db0b5f9880 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -90,6 +90,7 @@ http_archive( http_archive( name = "federated_language", patches = [ + "//third_party/federated_language:numpy.patch", "//third_party/federated_language:proto_library_loads.patch", "//third_party/federated_language:python_deps.patch", # Must come after `python_deps.patch`, this patches the output of `python_deps.patch`. @@ -98,9 +99,9 @@ http_archive( repo_mapping = { "@protobuf": "@com_google_protobuf", }, - sha256 = "d1500db4e6dc7403c3da121cf98ec03693b48384ec7cccae6fdc8ac3df182be0", - strip_prefix = "federated-language-be055feb0137577e0d27fe7c78aab87fbcb70b6d", - url = "https://github.com/google-parfait/federated-language/archive/be055feb0137577e0d27fe7c78aab87fbcb70b6d.tar.gz", + sha256 = "e2b13844d56233616d8ed664d15e155dbc6bb45743b6e5ce775a8553026b34a6", + strip_prefix = "federated-language-b685d2243891f9d7ca3c5820cfd690b4ecdb9697", + url = "https://github.com/google-parfait/federated-language/archive/b685d2243891f9d7ca3c5820cfd690b4ecdb9697.tar.gz", ) # The version of TensorFlow should match the version in diff --git a/examples/learning/federated_program/vizier/program.py b/examples/learning/federated_program/vizier/program.py index cc1630bc03..7eb7a9a612 100644 --- a/examples/learning/federated_program/vizier/program.py +++ b/examples/learning/federated_program/vizier/program.py @@ -96,7 +96,7 @@ def main(argv: Sequence[str]) -> None: raise app.UsageError('Too many command-line arguments.') context = tff.backends.native.create_async_local_cpp_execution_context() - context = tff.program.NativeFederatedContext(context) + context = federated_language.program.NativeFederatedContext(context) federated_language.framework.set_default_context(context) timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') diff --git a/examples/program/program.py b/examples/program/program.py index 174b9aa2c5..a38c952f90 100644 --- a/examples/program/program.py +++ b/examples/program/program.py @@ -82,7 +82,7 @@ def main(argv: Sequence[str]) -> None: # Create a context in which to execute the program logic. context = tff.backends.native.create_async_local_cpp_execution_context() - context = tff.program.NativeFederatedContext(context) + context = federated_language.program.NativeFederatedContext(context) federated_language.framework.set_default_context(context) # Create data sources that are compatible with the context and computations. diff --git a/examples/program/program_logic_test.py b/examples/program/program_logic_test.py index 6b73050d02..fc08db58e4 100644 --- a/examples/program/program_logic_test.py +++ b/examples/program/program_logic_test.py @@ -31,12 +31,14 @@ def _create_native_federated_context(): context = tff.backends.native.create_async_local_cpp_execution_context() - return tff.program.NativeFederatedContext(context) + return federated_language.program.NativeFederatedContext(context) def _create_mock_context() -> mock.Mock: return mock.create_autospec( - tff.program.NativeFederatedContext, spec_set=True, instance=True + federated_language.program.NativeFederatedContext, + spec_set=True, + instance=True, ) diff --git a/tensorflow_federated/python/aggregators/distributed_dp_test.py b/tensorflow_federated/python/aggregators/distributed_dp_test.py index 3d8e827a62..f65ffc42c3 100644 --- a/tensorflow_federated/python/aggregators/distributed_dp_test.py +++ b/tensorflow_federated/python/aggregators/distributed_dp_test.py @@ -189,8 +189,10 @@ def test_type_properties(self, value_type, mechanism): ) actual_next_type = process.next.type_signature self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type)) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) @parameterized.named_parameters( diff --git a/tensorflow_federated/python/aggregators/primitives_test.py b/tensorflow_federated/python/aggregators/primitives_test.py index 89c91951ad..39b4959796 100644 --- a/tensorflow_federated/python/aggregators/primitives_test.py +++ b/tensorflow_federated/python/aggregators/primitives_test.py @@ -182,8 +182,10 @@ def comp_py_bounds(value): np.array(1.0, dtype), ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - comp_py_bounds + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + comp_py_bounds + ) ) # Bounds provided as tff values. @@ -197,12 +199,12 @@ def comp_py_bounds(value): def comp_tff_bounds(value, upper_bound, lower_bound): return primitives.secure_quantized_sum(value, upper_bound, lower_bound) - try: - federated_language.framework.assert_not_contains_unsecure_aggregation( - comp_tff_bounds - ) - except AssertionError: - self.fail('Computation contains non-secure aggregation.') + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + comp_tff_bounds + ), + 'Computation contains non-secure aggregation.', + ) class SecureQuantizedSumTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_federated/python/aggregators/quantile_estimation_test.py b/tensorflow_federated/python/aggregators/quantile_estimation_test.py index 1d65c47196..b609db72e7 100644 --- a/tensorflow_federated/python/aggregators/quantile_estimation_test.py +++ b/tensorflow_federated/python/aggregators/quantile_estimation_test.py @@ -216,8 +216,10 @@ def test_secure_estimation_true_only_contains_secure_aggregation(self): learning_rate=1.0, secure_estimation=True, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - secure_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + secure_process.next + ) ) diff --git a/tensorflow_federated/python/aggregators/secure_test.py b/tensorflow_federated/python/aggregators/secure_test.py index 5e6432bdb1..a1ceac36b3 100644 --- a/tensorflow_federated/python/aggregators/secure_test.py +++ b/tensorflow_federated/python/aggregators/secure_test.py @@ -145,8 +145,10 @@ def test_type_properties(self, modulus, value_type, symmetric_range): self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) def test_float_modulus_raises(self): @@ -321,8 +323,10 @@ def test_type_properties_constant_bounds( self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) @parameterized.named_parameters( @@ -374,8 +378,10 @@ def test_type_properties_single_bound(self, value_type, dtype): self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) @parameterized.named_parameters( @@ -430,8 +436,10 @@ def test_type_properties_adaptive_bounds(self, value_type, dtype): self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) @parameterized.named_parameters( diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py index 677b5eeb1c..00c8796bda 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py @@ -1064,8 +1064,10 @@ def test_secure_sum(self, dp_mechanism): dp_mechanism=dp_mechanism, enable_secure_sum=True, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - hihi_computation + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + hihi_computation + ) ) @mock.patch('tensorflow.timestamp') diff --git a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py index 565bea60c7..a95911dd80 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py @@ -228,10 +228,10 @@ def _aggregation_predicate( # computation.protos; to avoid opening up a visibility hole that isn't # technically necessary here, we prefer to simply skip the static check here # for computations which cannot convert themselves to building blocks. - if hasattr( - after_merge, 'to_building_block' + if isinstance( + after_merge, federated_language.framework.ConcreteComputation ) and federated_language.framework.computation_contains( - after_merge.to_building_block(), _aggregation_predicate + after_merge, _aggregation_predicate ): formatted_aggregations = ', '.join( '{}: {}'.format(elem[0], elem[1]) for elem in aggregations diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_test.py b/tensorflow_federated/python/learning/algorithms/fed_avg_test.py index abac087885..22f40cdb88 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_test.py @@ -135,8 +135,10 @@ def test_weighted_fed_avg_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) def test_unweighted_fed_avg_with_only_secure_aggregation(self): @@ -149,8 +151,10 @@ def test_unweighted_fed_avg_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) @@ -182,8 +186,10 @@ def test_weighted_fed_avg_with_only_secure_aggregation(self, constructor): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py index 69d63dbff9..4b15600e60 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py @@ -154,8 +154,10 @@ def test_construction_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) def test_measurements_include_client_learning_rate(self): diff --git a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py index e2dda7a62a..3a7d074ffb 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py @@ -171,8 +171,10 @@ def test_weighted_fed_prox_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) def test_unweighted_fed_prox_with_only_secure_aggregation(self): @@ -186,8 +188,10 @@ def test_unweighted_fed_prox_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py index 0eb99a83c2..f1b4fe19bd 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py @@ -186,8 +186,10 @@ def test_no_unsecure_aggregation_with_secure_aggregator(self): model_aggregator=model_update_aggregator.secure_aggregator(), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) @@ -287,8 +289,10 @@ def test_no_unsecure_aggregation_with_secure_aggregator(self): model_aggregator=model_update_aggregator.secure_aggregator(), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) diff --git a/tensorflow_federated/python/learning/algorithms/mime_test.py b/tensorflow_federated/python/learning/algorithms/mime_test.py index ec4cea97a3..c6f31dd06b 100644 --- a/tensorflow_federated/python/learning/algorithms/mime_test.py +++ b/tensorflow_federated/python/learning/algorithms/mime_test.py @@ -479,8 +479,10 @@ def test_weighted_mime_lite_with_only_secure_aggregation(self): full_gradient_aggregator=aggregator, metrics_aggregator=metrics_aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) def test_unweighted_mime_lite_with_only_secure_aggregation(self): @@ -492,8 +494,10 @@ def test_unweighted_mime_lite_with_only_secure_aggregation(self): full_gradient_aggregator=aggregator, metrics_aggregator=metrics_aggregator.secure_sum_then_finalize, ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - learning_process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + learning_process.next + ) ) @tensorflow_test_utils.skip_test_for_multi_gpu diff --git a/tensorflow_federated/python/learning/metrics/aggregator_test.py b/tensorflow_federated/python/learning/metrics/aggregator_test.py index e9be3ece95..b65e9b81b7 100644 --- a/tensorflow_federated/python/learning/metrics/aggregator_test.py +++ b/tensorflow_federated/python/learning/metrics/aggregator_test.py @@ -354,8 +354,10 @@ def test_default_value_ranges_returns_correct_results( def aggregator_computation(unfinalized_metrics): return polymorphic_aggregator_computation(unfinalized_metrics) - federated_language.framework.assert_not_contains_unsecure_aggregation( - aggregator_computation + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + aggregator_computation + ) ) aggregated_metrics = aggregator_computation( diff --git a/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py b/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py index 16fe7e5285..e357a2a49a 100644 --- a/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py +++ b/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py @@ -264,8 +264,10 @@ def test_type_properties_with_inner_secure_sum_process( self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) @parameterized.named_parameters( @@ -518,8 +520,10 @@ def test_secure_sum_then_finalize_metrics(self): client_data = [local_unfinalized_metrics, local_unfinalized_metrics] output = process.next(state, client_data) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) _, unfinalized_metrics_accumulators = output.state @@ -600,8 +604,10 @@ def test_default_value_ranges_returns_correct_results(self): collections.OrderedDict, ) process = aggregate_factory.create(local_unfinalized_metrics_type) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) state = process.initialize() @@ -688,8 +694,10 @@ def test_user_value_ranges_returns_correct_results(self): metric_value_ranges ) process = aggregate_factory.create(local_unfinalized_metrics_type) - federated_language.framework.assert_not_contains_unsecure_aggregation( - process.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + process.next + ) ) state = process.initialize() diff --git a/tensorflow_federated/python/learning/model_update_aggregator_test.py b/tensorflow_federated/python/learning/model_update_aggregator_test.py index 6ec0d55aa4..113a3ea7df 100644 --- a/tensorflow_federated/python/learning/model_update_aggregator_test.py +++ b/tensorflow_federated/python/learning/model_update_aggregator_test.py @@ -191,24 +191,30 @@ def test_weighted_secure_aggregator_only_contains_secure_aggregation(self): aggregator = model_update_aggregator.secure_aggregator( weighted=True ).create(_FLOAT_MATRIX_TYPE, _FLOAT_TYPE) - federated_language.framework.assert_not_contains_unsecure_aggregation( - aggregator.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + aggregator.next + ) ) def test_unweighted_secure_aggregator_only_contains_secure_aggregation(self): aggregator = model_update_aggregator.secure_aggregator( weighted=False ).create(_FLOAT_MATRIX_TYPE) - federated_language.framework.assert_not_contains_unsecure_aggregation( - aggregator.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + aggregator.next + ) ) def test_ddp_secure_aggregator_only_contains_secure_aggregation(self): aggregator = model_update_aggregator.ddp_secure_aggregator( noise_multiplier=1e-2, expected_clients_per_round=10 ).create(_FLOAT_MATRIX_TYPE) - federated_language.framework.assert_not_contains_unsecure_aggregation( - aggregator.next + self.assertFalse( + federated_language.framework.computation_contains_unsecure_aggregation( + aggregator.next + ) ) @parameterized.named_parameters( diff --git a/tensorflow_federated/python/learning/programs/BUILD b/tensorflow_federated/python/learning/programs/BUILD index 0369b43e75..550b35dcf6 100644 --- a/tensorflow_federated/python/learning/programs/BUILD +++ b/tensorflow_federated/python/learning/programs/BUILD @@ -46,7 +46,6 @@ py_test( "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", "//tensorflow_federated/python/program:file_program_state_manager", - "//tensorflow_federated/python/program:native_platform", "@federated_language//federated_language", ], ) @@ -84,7 +83,6 @@ py_test( "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:native_platform", "@federated_language//federated_language", ], ) @@ -109,7 +107,6 @@ py_test( ":vizier_program_logic", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:native_platform", "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py b/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py index 9973da17c9..dc53850d1d 100644 --- a/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py @@ -30,7 +30,6 @@ from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process from tensorflow_federated.python.program import file_program_state_manager -from tensorflow_federated.python.program import native_platform # Convenience aliases. TensorType = federated_language.TensorType @@ -108,7 +107,7 @@ async def _value(): coro = _value() task = asyncio.create_task(coro) - return native_platform.NativeValueReference(task, value_type) + return federated_language.program.NativeValueReference(task, value_type) test_value = collections.OrderedDict( a=awaitable_value('foo', federated_language.TensorType(np.str_)), @@ -237,7 +236,7 @@ def test_get_model_ids_for_multi_model_eval_no_model_ids_returns_none(self): def _create_test_context() -> federated_language.program.FederatedContext: - return native_platform.NativeFederatedContext( + return federated_language.program.NativeFederatedContext( execution_contexts.create_async_local_cpp_execution_context() ) diff --git a/tensorflow_federated/python/learning/programs/training_program_logic_test.py b/tensorflow_federated/python/learning/programs/training_program_logic_test.py index 03f76d6e22..ff8405770c 100644 --- a/tensorflow_federated/python/learning/programs/training_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/training_program_logic_test.py @@ -29,7 +29,6 @@ from tensorflow_federated.python.learning.programs import training_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import native_platform # Convenience aliases. ProgramState = training_program_logic.ProgramState @@ -82,7 +81,7 @@ async def task(): def _create_test_context() -> federated_language.program.FederatedContext: - return native_platform.NativeFederatedContext( + return federated_language.program.NativeFederatedContext( execution_contexts.create_async_local_cpp_execution_context() ) diff --git a/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py b/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py index e27626c87c..c39824df46 100644 --- a/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py @@ -25,12 +25,13 @@ from tensorflow_federated.python.learning.programs import vizier_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import native_platform def _create_mock_context() -> mock.Mock: return mock.create_autospec( - native_platform.NativeFederatedContext, spec_set=True, instance=True + federated_language.program.NativeFederatedContext, + spec_set=True, + instance=True, ) diff --git a/tensorflow_federated/python/program/BUILD b/tensorflow_federated/python/program/BUILD index a77dc5d57d..88f902b690 100644 --- a/tensorflow_federated/python/program/BUILD +++ b/tensorflow_federated/python/program/BUILD @@ -26,7 +26,6 @@ py_library( ":dataset_data_source", ":file_program_state_manager", ":file_release_manager", - ":native_platform", ":tensorboard_release_manager", ], ) @@ -123,29 +122,6 @@ py_test( deps = [":file_utils"], ) -py_library( - name = "native_platform", - srcs = ["native_platform.py"], - deps = [ - ":structure_utils", - "//tensorflow_federated/python/common_libs:structure", - "@federated_language//federated_language", - ], -) - -py_test( - name = "native_platform_test", - srcs = ["native_platform_test.py"], - deps = [ - ":native_platform", - ":program_test_utils", - ":structure_utils", - "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "@federated_language//federated_language", - ], -) - py_library( name = "program_test_utils", testonly = True, diff --git a/tensorflow_federated/python/program/__init__.py b/tensorflow_federated/python/program/__init__.py index 8eb24ecefd..badc4046f6 100644 --- a/tensorflow_federated/python/program/__init__.py +++ b/tensorflow_federated/python/program/__init__.py @@ -23,7 +23,5 @@ from tensorflow_federated.python.program.file_release_manager import CSVKeyFieldnameNotFoundError from tensorflow_federated.python.program.file_release_manager import CSVSaveMode from tensorflow_federated.python.program.file_release_manager import SavedModelFileReleaseManager -from tensorflow_federated.python.program.native_platform import NativeFederatedContext -from tensorflow_federated.python.program.native_platform import NativeValueReference from tensorflow_federated.python.program.tensorboard_release_manager import TensorBoardReleaseManager # pylint: enable=g-importing-member diff --git a/tensorflow_federated/python/program/native_platform.py b/tensorflow_federated/python/program/native_platform.py deleted file mode 100644 index 74dd74e966..0000000000 --- a/tensorflow_federated/python/program/native_platform.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2022, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A federated platform implemented using native TFF components.""" - -import asyncio -from collections.abc import Mapping -from typing import Optional, Union - -import federated_language -import tree - -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.program import structure_utils - - -class NativeValueReference( - federated_language.program.MaterializableValueReference -): - """A `federated_language.program.MaterializableValueReference` backed by a task.""" - - def __init__( - self, - task: asyncio.Task, - type_signature: federated_language.program.MaterializableTypeSignature, - ): - """Returns an initialized `tff.program.NativeValueReference`. - - Args: - task: An `asyncio.Task` to run. - type_signature: The `federated_language.Type` of this object. - """ - self._task = task - self._type_signature = type_signature - - @property - def type_signature( - self, - ) -> federated_language.program.MaterializableTypeSignature: - """The `federated_language.TensorType` of this object.""" - return self._type_signature - - async def get_value(self) -> federated_language.program.MaterializedValue: - """Returns the referenced value as a numpy scalar or array.""" - return await self._task - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, NativeValueReference): - return NotImplemented - return ( - self._type_signature, - self._task, - ) == ( - other._type_signature, - other._task, - ) - - -def _create_structure_of_references( - task: asyncio.Task, - type_signature: federated_language.Type, -) -> structure_utils.Structure[NativeValueReference]: - """Returns a structure of `tff.program.NativeValueReference`s. - - Args: - task: A task used to create the structure of - `tff.program.NativeValueReference`s. - type_signature: The `federated_language.Type` of the value returned by - `task`; must contain only structures, server-placed values, or tensors. - - Raises: - NotImplementedError: If `type_signature` contains an unexpected type. - """ - if isinstance(type_signature, federated_language.StructType): - - def _get_container_cls( - type_spec: federated_language.StructType, - ) -> type[object]: - container_cls = type_spec.python_container - if container_cls is None: - has_names = [name is not None for name, _ in type_spec.items()] - if any(has_names): - if not all(has_names): - raise ValueError( - 'Expected `type_spec` to have either all named or unnamed' - f' elements, found {type_spec}.' - ) - container_cls = dict - else: - container_cls = list - return container_cls - - async def _get_item( - task: asyncio.Task, key: Union[str, int] - ) -> federated_language.program.MaterializedValue: - value = await task - return value[key] - - elements = [] - for index, (name, element_type) in enumerate(type_signature.items()): - container_cls = _get_container_cls(type_signature) - if issubclass(container_cls, Mapping): - key = name - else: - key = index - element = _get_item(task, key) - element_task = asyncio.create_task(element) - element = _create_structure_of_references(element_task, element_type) - elements.append(element) - return federated_language.framework.to_structure_with_type( - elements, type_signature - ) - elif ( - isinstance(type_signature, federated_language.FederatedType) - and type_signature.placement == federated_language.SERVER - ): - return _create_structure_of_references(task, type_signature.member) - elif isinstance(type_signature, federated_language.SequenceType): - return NativeValueReference(task, type_signature) - elif isinstance(type_signature, federated_language.TensorType): - return NativeValueReference(task, type_signature) - else: - raise NotImplementedError(f'Unexpected type found: {type_signature}.') - - -class NativeFederatedContext(federated_language.program.FederatedContext): - """A `federated_language.program.FederatedContext` backed by an execution context.""" - - def __init__( - self, context: federated_language.framework.AsyncExecutionContext - ): - """Returns an initialized `tff.program.NativeFederatedContext`. - - Args: - context: An `federated_language.framework.AsyncExecutionContext`. - """ - self._context = context - - def invoke( - self, - comp: federated_language.framework.Computation, - arg: Optional[federated_language.program.ComputationArg], - ) -> structure_utils.Structure[NativeValueReference]: - """Invokes the `comp` with the argument `arg`. - - Args: - comp: The `federated_language.Computation` being invoked. - arg: The optional argument of `comp`; server-placed values must be - represented by `federated_language.program.MaterializableStructure`, and - client-placed values must be represented by structures of values - returned by a `federated_language.program.FederatedDataSourceIterator`. - - Returns: - The result of invocation; a structure of - `federated_language.program.MaterializableValueReference`. - - Raises: - ValueError: If the result type of the invoked computation does not contain - only structures, server-placed values, or tensors. - Raises: - ValueError: If the result type of `comp` does not contain only structures, - server-placed values, or tensors. - """ - result_type = comp.type_signature.result - if not federated_language.program.contains_only_server_placed_data( - result_type - ): - raise ValueError( - 'Expected the result type of `comp` to contain only structures, ' - f'server-placed values, or tensors, found {result_type}.' - ) - - async def _invoke( - context: federated_language.framework.AsyncExecutionContext, - comp: federated_language.framework.Computation, - arg: federated_language.program.MaterializableStructure, - ) -> federated_language.program.MaterializedStructure: - if comp.type_signature.parameter is not None: - - def _to_python(obj): - if isinstance(obj, structure.Struct): - return structure.to_odict_or_tuple(obj) - else: - return None - - arg = tree.traverse(_to_python, arg) - arg = await federated_language.program.materialize_value(arg) - - return await context.invoke(comp, arg) - - coro = _invoke(self._context, comp, arg) - task = asyncio.create_task(coro) - return _create_structure_of_references(task, result_type) diff --git a/tensorflow_federated/python/program/native_platform_test.py b/tensorflow_federated/python/program/native_platform_test.py deleted file mode 100644 index 2afaf21f26..0000000000 --- a/tensorflow_federated/python/program/native_platform_test.py +++ /dev/null @@ -1,656 +0,0 @@ -# Copyright 2022, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import unittest -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import federated_language -import numpy as np -import tree - -from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.program import native_platform -from tensorflow_federated.python.program import program_test_utils -from tensorflow_federated.python.program import structure_utils - - -def _create_task(value: object) -> object: - - async def _fn(value: object) -> object: - return value - - coro = _fn(value) - return asyncio.create_task(coro) - - -def _create_identity_federated_computation( - type_signature: federated_language.Type, -) -> federated_language.framework.Computation: - @federated_language.federated_computation(type_signature) - def _identity(value: object) -> object: - return value - - return _identity - - -def _create_identity_tensorflow_computation( - type_signature: federated_language.Type, -) -> federated_language.framework.Computation: - @tensorflow_computation.tf_computation(type_signature) - def _identity(value: object) -> object: - return value - - return _identity - - -class NativeValueReferenceTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - ( - 'tensor_bool', - lambda: _create_task(True), - federated_language.TensorType(np.bool_), - True, - ), - ( - 'tensor_int', - lambda: _create_task(1), - federated_language.TensorType(np.int32), - 1, - ), - ( - 'tensor_str', - lambda: _create_task('abc'), - federated_language.TensorType(np.str_), - 'abc', - ), - ( - 'sequence', - lambda: _create_task([1, 2, 3]), - federated_language.SequenceType(np.int32), - [1, 2, 3], - ), - ) - async def test_get_value_returns_value( - self, task_factory, type_signature, expected_value - ): - task = task_factory() - reference = native_platform.NativeValueReference(task, type_signature) - - actual_value = await reference.get_value() - - tree.assert_same_structure(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - - -class CreateStructureOfReferencesTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - ( - 'tensor', - lambda: _create_task(1), - federated_language.TensorType(np.int32), - lambda: native_platform.NativeValueReference( - _create_task(1), federated_language.TensorType(np.int32) - ), - ), - ( - 'sequence', - lambda: _create_task([1, 2, 3]), - federated_language.SequenceType(np.int32), - lambda: native_platform.NativeValueReference( - _create_task([1, 2, 3]), - federated_language.SequenceType(np.int32), - ), - ), - ( - 'federated_server', - lambda: _create_task(1), - federated_language.FederatedType(np.int32, federated_language.SERVER), - lambda: native_platform.NativeValueReference( - _create_task(1), federated_language.TensorType(np.int32) - ), - ), - ( - 'struct_unnamed', - lambda: _create_task([True, 1, 'abc']), - federated_language.StructWithPythonType( - [np.bool_, np.int32, np.str_], list - ), - lambda: [ - native_platform.NativeValueReference( - _create_task(True), federated_language.TensorType(np.bool_) - ), - native_platform.NativeValueReference( - _create_task(1), federated_language.TensorType(np.int32) - ), - native_platform.NativeValueReference( - _create_task('abc'), federated_language.TensorType(np.str_) - ), - ], - ), - ( - 'struct_named', - lambda: _create_task({'a': True, 'b': 1, 'c': 'abc'}), - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ('c', np.str_), - ], - dict, - ), - lambda: { - 'a': native_platform.NativeValueReference( - _create_task(True), federated_language.TensorType(np.bool_) - ), - 'b': native_platform.NativeValueReference( - _create_task(1), federated_language.TensorType(np.int32) - ), - 'c': native_platform.NativeValueReference( - _create_task('abc'), federated_language.TensorType(np.str_) - ), - }, - ), - ( - 'struct_nested', - lambda: _create_task({'x': {'a': True, 'b': 1}, 'y': {'c': 'abc'}}), - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ( - 'y', - federated_language.StructWithPythonType( - [ - ('c', np.str_), - ], - dict, - ), - ), - ], - dict, - ), - lambda: { - 'x': { - 'a': native_platform.NativeValueReference( - _create_task(True), - federated_language.TensorType(np.bool_), - ), - 'b': native_platform.NativeValueReference( - _create_task(1), - federated_language.TensorType(np.int32), - ), - }, - 'y': { - 'c': native_platform.NativeValueReference( - _create_task('abc'), - federated_language.TensorType(np.str_), - ), - }, - }, - ), - ) - async def test_returns_value( - self, task_factory, type_signature, expected_value_factory - ): - task = task_factory() - actual_value = native_platform._create_structure_of_references( - task, type_signature - ) - - expected_value = expected_value_factory() - actual_value = await federated_language.program.materialize_value( - actual_value - ) - expected_value = await federated_language.program.materialize_value( - expected_value - ) - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'federated_clients', - federated_language.FederatedType( - np.int32, federated_language.CLIENTS - ), - ), - ('function', federated_language.FunctionType(np.int32, np.int32)), - ('placement', federated_language.PlacementType()), - ) - async def test_raises_not_implemented_error_with_type_signature( - self, type_signature - ): - task = _create_task(1) - - with self.assertRaises(NotImplementedError): - native_platform._create_structure_of_references(task, type_signature) - - -class NativeFederatedContextTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - ( - 'tensor', - _create_identity_federated_computation( - federated_language.TensorType(np.int32) - ), - 1, - 1, - ), - ( - 'sequence', - _create_identity_tensorflow_computation( - federated_language.SequenceType(np.int32) - ), - [1, 2, 3], - [1, 2, 3], - ), - ( - 'federated_server', - _create_identity_federated_computation( - federated_language.FederatedType( - np.int32, federated_language.SERVER - ) - ), - 1, - 1, - ), - ( - 'struct_unnamed', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [np.bool_, np.int32, np.str_], list - ) - ), - [True, 1, 'abc'], - [True, 1, b'abc'], - ), - ( - 'struct_named_ordered', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ('c', np.str_), - ], - dict, - ) - ), - {'a': True, 'b': 1, 'c': 'abc'}, - {'a': True, 'b': 1, 'c': b'abc'}, - ), - ( - 'struct_named_unordered', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ('c', np.str_), - ('b', np.int32), - ('a', np.bool_), - ], - dict, - ) - ), - {'c': 'abc', 'b': 1, 'a': True}, - {'c': b'abc', 'b': 1, 'a': True}, - ), - ) - async def test_invoke_returns_result(self, comp, arg, expected_value): - context = execution_contexts.create_async_local_cpp_execution_context() - context = native_platform.NativeFederatedContext(context) - - with program_test_utils.assert_not_warns(RuntimeWarning): - result = context.invoke(comp, arg) - actual_value = await federated_language.program.materialize_value(result) - - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'struct_nested', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ( - 'y', - federated_language.StructWithPythonType( - [ - ('c', np.str_), - ], - dict, - ), - ), - ], - dict, - ) - ), - {'x': {'a': True, 'b': 1}, 'y': {'c': 'abc'}}, - {'x': {'a': True, 'b': 1}, 'y': {'c': b'abc'}}, - ), - ( - 'struct_partially_empty', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ('y', federated_language.StructWithPythonType([], dict)), - ], - dict, - ) - ), - {'x': {'a': True, 'b': 1}, 'y': {}}, - {'x': {'a': True, 'b': 1}, 'y': {}}, - ), - ) - async def test_invoke_returns_result_materialized_sequentially( - self, comp, arg, expected_value - ): - context = execution_contexts.create_async_local_cpp_execution_context() - mock_context = mock.Mock( - spec=federated_language.framework.AsyncExecutionContext, wraps=context - ) - context = native_platform.NativeFederatedContext(mock_context) - - with program_test_utils.assert_not_warns(RuntimeWarning): - result = context.invoke(comp, arg) - flattened = structure_utils.flatten(result) - materialized = [await v.get_value() for v in flattened] - actual_value = structure_utils.unflatten_as(result, materialized) - - self.assertEqual(actual_value, expected_value) - mock_context.invoke.assert_called_once() - - @parameterized.named_parameters( - ( - 'struct_nested', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ( - 'y', - federated_language.StructWithPythonType( - [ - ('c', np.str_), - ], - dict, - ), - ), - ], - dict, - ) - ), - {'x': {'a': True, 'b': 1}, 'y': {'c': 'abc'}}, - {'x': {'a': True, 'b': 1}, 'y': {'c': b'abc'}}, - ), - ( - 'struct_partially_empty', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ('y', federated_language.StructWithPythonType([], dict)), - ], - dict, - ) - ), - {'x': {'a': True, 'b': 1}, 'y': {}}, - {'x': {'a': True, 'b': 1}, 'y': {}}, - ), - ) - async def test_invoke_returns_result_materialized_concurrently( - self, comp, arg, expected_value - ): - context = execution_contexts.create_async_local_cpp_execution_context() - mock_context = mock.Mock( - spec=federated_language.framework.AsyncExecutionContext, wraps=context - ) - context = native_platform.NativeFederatedContext(mock_context) - - with program_test_utils.assert_not_warns(RuntimeWarning): - result = context.invoke(comp, arg) - actual_value = await federated_language.program.materialize_value(result) - - self.assertEqual(actual_value, expected_value) - mock_context.invoke.assert_called_once() - - @parameterized.named_parameters( - ( - 'struct_nested', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ( - 'y', - federated_language.StructWithPythonType( - [ - ('c', np.str_), - ], - dict, - ), - ), - ], - dict, - ) - ), - {'x': {'a': True, 'b': 1}, 'y': {'c': 'abc'}}, - {'x': {'a': True, 'b': 1}, 'y': {'c': b'abc'}}, - ), - ( - 'struct_partially_empty', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ( - 'x', - federated_language.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - dict, - ), - ), - ('y', federated_language.StructWithPythonType([], dict)), - ], - dict, - ) - ), - {'x': {'a': True, 'b': 1}, 'y': {}}, - {'x': {'a': True, 'b': 1}, 'y': {}}, - ), - ) - async def test_invoke_returns_result_materialized_multiple( - self, comp, arg, expected_value - ): - context = execution_contexts.create_async_local_cpp_execution_context() - mock_context = mock.Mock( - spec=federated_language.framework.AsyncExecutionContext, wraps=context - ) - context = native_platform.NativeFederatedContext(mock_context) - - with program_test_utils.assert_not_warns(RuntimeWarning): - result = context.invoke(comp, arg) - actual_value = await asyncio.gather( - federated_language.program.materialize_value(result), - federated_language.program.materialize_value(result), - federated_language.program.materialize_value(result), - ) - - expected_value = [expected_value] * 3 - self.assertEqual(actual_value, expected_value) - mock_context.invoke.assert_called_once() - - @parameterized.named_parameters( - ( - 'struct_unnamed_empty', - _create_identity_federated_computation( - federated_language.StructWithPythonType([], list) - ), - [], - [], - ), - ( - 'struct_named_empty', - _create_identity_federated_computation( - federated_language.StructWithPythonType([], dict) - ), - {}, - {}, - ), - ( - 'struct_nested_empty', - _create_identity_federated_computation( - federated_language.StructWithPythonType( - [ - ('x', federated_language.StructWithPythonType([], dict)), - ('y', federated_language.StructWithPythonType([], dict)), - ], - dict, - ) - ), - {'x': {}, 'y': {}}, - {'x': {}, 'y': {}}, - ), - ) - async def test_invoke_returns_result_comp_not_called( - self, comp, arg, expected_value - ): - context = execution_contexts.create_async_local_cpp_execution_context() - mock_context = mock.Mock( - spec=federated_language.framework.AsyncExecutionContext, wraps=context - ) - context = native_platform.NativeFederatedContext(mock_context) - - with program_test_utils.assert_not_warns(RuntimeWarning): - result = context.invoke(comp, arg) - actual_value = await federated_language.program.materialize_value(result) - - self.assertEqual(actual_value, expected_value) - mock_context.invoke.assert_not_called() - - @parameterized.named_parameters( - ( - 'federated_clients', - _create_identity_federated_computation( - federated_language.FederatedType( - np.int32, federated_language.CLIENTS - ) - ), - 1, - ), - ( - 'function', - _create_identity_federated_computation( - federated_language.FunctionType(np.int32, np.int32) - ), - _create_identity_federated_computation( - federated_language.TensorType(np.int32) - ), - ), - ( - 'placement', - _create_identity_federated_computation( - federated_language.PlacementType() - ), - None, - ), - ) - def test_invoke_raises_value_error_with_comp(self, comp, arg): - context = execution_contexts.create_async_local_cpp_execution_context() - context = native_platform.NativeFederatedContext(context) - - with self.assertRaises(ValueError): - context.invoke(comp, arg) - - -if __name__ == '__main__': - absltest.main() diff --git a/third_party/federated_language/numpy.patch b/third_party/federated_language/numpy.patch new file mode 100644 index 0000000000..2cde5a344c --- /dev/null +++ b/third_party/federated_language/numpy.patch @@ -0,0 +1,35 @@ +diff --git federated_language/types/dtype_utils.py federated_language/types/dtype_utils.py +index 59f8937..342fbee 100644 +--- federated_language/types/dtype_utils.py ++++ federated_language/types/dtype_utils.py +@@ -102,15 +102,21 @@ def can_cast( + dtype: The dtype to check against. + """ + +- # When encountering an overflow, numpy issues a `RuntimeWarning` for floating +- # dtypes and raises an `OverflowError` for integer dtypes. +- with warnings.catch_warnings(): +- warnings.simplefilter(action='error', category=RuntimeWarning) +- try: +- np.asarray(value, dtype=dtype) +- return True +- except (OverflowError, RuntimeWarning): +- return False ++ # `np.can_cast` does not support Python scalars (since version 2.0). Casting ++ # the value to a numpy value and testing for an overflow is equivalent to ++ # testing the Python value. ++ numpy_version = tuple(int(x) for x in np.__version__.split('.')) ++ if numpy_version >= (2, 0): ++ # When encountering an overflow, numpy issues a `RuntimeWarning` for ++ # floating dtypes and raises an `OverflowError` for integer dtypes. ++ with warnings.catch_warnings(action='error', category=RuntimeWarning): ++ try: ++ np.asarray(value, dtype=dtype) ++ return True ++ except (OverflowError, RuntimeWarning): ++ return False ++ else: ++ return np.can_cast(value, dtype) + + + def infer_dtype( diff --git a/third_party/federated_language/python_deps.patch b/third_party/federated_language/python_deps.patch index 2b76ba637a..3c5cedce8c 100644 --- a/third_party/federated_language/python_deps.patch +++ b/third_party/federated_language/python_deps.patch @@ -1,5 +1,5 @@ diff --git federated_language/common_libs/BUILD federated_language/common_libs/BUILD -index 7e91f23..08aa419 100644 +index b7d3dda..08aa419 100644 --- federated_language/common_libs/BUILD +++ federated_language/common_libs/BUILD @@ -37,7 +37,6 @@ py_library( @@ -10,7 +10,7 @@ index 7e91f23..08aa419 100644 ) py_test( -@@ -57,29 +56,19 @@ py_test( +@@ -57,26 +56,19 @@ py_test( "golden_test_goldens/test_check_string_succeeds.expected", "golden_test_goldens/test_check_string_updates.expected", ], @@ -24,10 +24,7 @@ index 7e91f23..08aa419 100644 py_library( name = "py_typecheck", srcs = ["py_typecheck.py"], -- deps = [ -- "@pypi//attrs", -- "@pypi//typing_extensions", -- ], +- deps = ["@pypi//typing_extensions"], ) py_test( @@ -42,7 +39,7 @@ index 7e91f23..08aa419 100644 ) py_library( -@@ -94,10 +83,7 @@ py_test( +@@ -91,10 +83,7 @@ py_test( name = "retrying_test", size = "small", srcs = ["retrying_test.py"], @@ -54,7 +51,7 @@ index 7e91f23..08aa419 100644 ) py_library( -@@ -108,39 +94,25 @@ py_library( +@@ -105,39 +94,25 @@ py_library( py_library( name = "structure", srcs = ["structure.py"], @@ -99,10 +96,10 @@ index 7e91f23..08aa419 100644 + deps = [":tracing"], ) diff --git federated_language/compiler/BUILD federated_language/compiler/BUILD -index 9752558..9dde47b 100644 +index be5266e..f5c40e8 100644 --- federated_language/compiler/BUILD +++ federated_language/compiler/BUILD -@@ -43,8 +43,6 @@ py_library( +@@ -42,8 +42,6 @@ py_library( "//federated_language/proto:array_py_pb2", "//federated_language/types:array_shape", "//federated_language/types:dtype_utils", @@ -111,7 +108,7 @@ index 9752558..9dde47b 100644 ], ) -@@ -55,9 +53,6 @@ py_test( +@@ -54,9 +52,6 @@ py_test( ":array", "//federated_language/proto:array_py_pb2", "//federated_language/proto:data_type_py_pb2", @@ -122,31 +119,24 @@ index 9752558..9dde47b 100644 ) @@ -137,8 +132,6 @@ py_test( + "//federated_language/types:computation_types", "//federated_language/types:placements", "//federated_language/types:type_analysis", - "//federated_language/types:type_test_utils", - "@pypi//absl_py", - "@pypi//numpy", ], ) -@@ -153,7 +146,6 @@ py_library( - "//federated_language/types:computation_types", - "//federated_language/types:placements", - "@protobuf//:protobuf_python", -- "@pypi//numpy", - ], - ) - -@@ -171,7 +163,6 @@ py_library( +@@ -156,8 +149,6 @@ py_library( "//federated_language/types:type_analysis", "//federated_language/types:typed_object", "@protobuf//:protobuf_python", - "@pypi//numpy", +- "@pypi//typing_extensions", ], ) -@@ -192,10 +183,6 @@ py_test( +@@ -178,10 +169,6 @@ py_test( "//federated_language/types:computation_types", "//federated_language/types:placements", "@protobuf//:protobuf_python", @@ -157,7 +147,7 @@ index 9752558..9dde47b 100644 ], ) -@@ -217,8 +204,6 @@ py_test( +@@ -203,8 +190,6 @@ py_test( "//federated_language/proto:computation_py_pb2", "//federated_language/types:computation_types", "//federated_language/types:type_factory", @@ -166,7 +156,7 @@ index 9752558..9dde47b 100644 ], ) -@@ -230,7 +215,6 @@ py_library( +@@ -216,7 +201,6 @@ py_library( "//federated_language/types:computation_types", "//federated_language/types:placements", "//federated_language/types:type_factory", @@ -174,7 +164,7 @@ index 9752558..9dde47b 100644 ], ) -@@ -238,10 +222,7 @@ py_test( +@@ -224,10 +208,7 @@ py_test( name = "intrinsic_defs_test", size = "small", srcs = ["intrinsic_defs_test.py"], @@ -186,16 +176,16 @@ index 9752558..9dde47b 100644 ) py_library( -@@ -267,8 +248,6 @@ py_test( - "//federated_language/proto:computation_py_pb2", +@@ -252,8 +233,6 @@ py_test( "//federated_language/types:computation_types", "//federated_language/types:placements", + "@protobuf//:protobuf_python", - "@pypi//absl_py", - "@pypi//numpy", ], ) -@@ -298,7 +277,5 @@ py_test( +@@ -282,7 +261,5 @@ py_test( ":tree_analysis", "//federated_language/types:computation_types", "//federated_language/types:placements", @@ -204,28 +194,19 @@ index 9752558..9dde47b 100644 ], ) diff --git federated_language/computation/BUILD federated_language/computation/BUILD -index 8c4e69a..da34426 100644 +index a688f3f..f99a3bc 100644 --- federated_language/computation/BUILD +++ federated_language/computation/BUILD -@@ -70,8 +70,6 @@ py_test( - "//federated_language/proto:computation_py_pb2", - "//federated_language/types:computation_types", - "//federated_language/types:type_test_utils", -- "@pypi//absl_py", -- "@pypi//numpy", - ], - ) - -@@ -98,8 +96,6 @@ py_test( - "//federated_language/compiler:computation_factory", +@@ -80,8 +80,6 @@ py_test( "//federated_language/context_stack:context_stack_impl", + "//federated_language/proto:computation_py_pb2", "//federated_language/types:computation_types", - "@pypi//absl_py", - "@pypi//numpy", ], ) -@@ -128,8 +124,6 @@ py_test( +@@ -112,8 +110,6 @@ py_test( "//federated_language/context_stack:context_stack_impl", "//federated_language/proto:computation_py_pb2", "//federated_language/types:computation_types", @@ -234,7 +215,7 @@ index 8c4e69a..da34426 100644 ], ) -@@ -154,8 +148,6 @@ py_test( +@@ -137,8 +133,6 @@ py_test( ":function_utils", "//federated_language/common_libs:structure", "//federated_language/types:computation_types", @@ -243,7 +224,7 @@ index 8c4e69a..da34426 100644 ], ) -@@ -180,7 +172,5 @@ py_test( +@@ -163,7 +157,5 @@ py_test( "//federated_language/proto:computation_py_pb2", "//federated_language/types:computation_types", "//federated_language/types:type_conversions", @@ -252,18 +233,10 @@ index 8c4e69a..da34426 100644 ], ) diff --git federated_language/context_stack/BUILD federated_language/context_stack/BUILD -index 65ac8f8..2a3f8f8 100644 +index 73814ef..29de268 100644 --- federated_language/context_stack/BUILD +++ federated_language/context_stack/BUILD @@ -63,7 +63,6 @@ py_test( - deps = [ - ":context_stack_impl", - ":context_stack_test_utils", -- "@pypi//absl_py", - ], - ) - -@@ -73,7 +72,6 @@ py_library( deps = [ ":context_base", ":context_stack_impl", @@ -271,15 +244,15 @@ index 65ac8f8..2a3f8f8 100644 ], ) -@@ -83,7 +81,6 @@ py_test( - deps = [ +@@ -83,7 +82,6 @@ py_test( + ":context_base", ":context_stack_impl", ":context_stack_test_utils", - "@pypi//absl_py", ], ) -@@ -100,7 +97,6 @@ py_test( +@@ -100,7 +98,6 @@ py_test( deps = [ ":context_stack_impl", ":get_context_stack", @@ -287,19 +260,19 @@ index 65ac8f8..2a3f8f8 100644 ], ) -@@ -127,7 +123,6 @@ py_test( +@@ -127,7 +124,6 @@ py_test( + ":context_base", ":context_stack_impl", - ":context_stack_test_utils", ":set_default_context", - "@pypi//absl_py", ], ) diff --git federated_language/execution_contexts/BUILD federated_language/execution_contexts/BUILD -index 01b0d32..58ce2a3 100644 +index e392531..2c1974b 100644 --- federated_language/execution_contexts/BUILD +++ federated_language/execution_contexts/BUILD -@@ -51,7 +51,6 @@ py_library( +@@ -50,7 +50,6 @@ py_library( "//federated_language/types:computation_types", "//federated_language/types:type_conversions", "//federated_language/types:typed_object", @@ -307,15 +280,15 @@ index 01b0d32..58ce2a3 100644 ], ) -@@ -62,7 +61,6 @@ py_test( +@@ -61,7 +60,6 @@ py_test( deps = [ ":async_execution_context", - "//federated_language/executors:executors_errors", + "//federated_language/executors:executor_base", - "@pypi//absl_py", ], ) -@@ -82,7 +80,6 @@ py_test( +@@ -81,7 +79,6 @@ py_test( deps = [ ":compiler_pipeline", "//federated_language/computation:computation_base", @@ -324,7 +297,7 @@ index 01b0d32..58ce2a3 100644 ) diff --git federated_language/executors/BUILD federated_language/executors/BUILD -index 885c348..56a8de2 100644 +index 42cc0c7..49bbcde 100644 --- federated_language/executors/BUILD +++ federated_language/executors/BUILD @@ -51,8 +51,6 @@ py_test( @@ -337,10 +310,10 @@ index 885c348..56a8de2 100644 ) diff --git federated_language/federated_context/BUILD federated_language/federated_context/BUILD -index 4f378fa..58c7415 100644 +index 94ee69e..7c1dbdc 100644 --- federated_language/federated_context/BUILD +++ federated_language/federated_context/BUILD -@@ -58,8 +58,6 @@ py_test( +@@ -57,8 +57,6 @@ py_test( "//federated_language/context_stack:get_context_stack", "//federated_language/context_stack:runtime_error_context", "//federated_language/types:computation_types", @@ -349,7 +322,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -91,8 +89,6 @@ py_test( +@@ -90,8 +88,6 @@ py_test( "//federated_language/context_stack:context_stack_impl", "//federated_language/types:computation_types", "//federated_language/types:placements", @@ -358,7 +331,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -121,8 +117,6 @@ py_test( +@@ -120,8 +116,6 @@ py_test( "//federated_language/computation:function_utils", "//federated_language/context_stack:context_stack_impl", "//federated_language/types:computation_types", @@ -367,7 +340,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -145,7 +139,6 @@ py_library( +@@ -144,7 +138,6 @@ py_library( "//federated_language/types:placements", "//federated_language/types:type_analysis", "//federated_language/types:type_factory", @@ -375,7 +348,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -162,8 +155,6 @@ py_test( +@@ -161,8 +154,6 @@ py_test( "//federated_language/context_stack:context_stack_test_utils", "//federated_language/types:computation_types", "//federated_language/types:placements", @@ -384,7 +357,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -186,7 +177,6 @@ py_library( +@@ -185,7 +176,6 @@ py_library( "//federated_language/types:placements", "//federated_language/types:type_conversions", "//federated_language/types:typed_object", @@ -392,7 +365,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -204,9 +194,6 @@ py_test( +@@ -203,9 +193,6 @@ py_test( "//federated_language/context_stack:context_stack_impl", "//federated_language/types:computation_types", "//federated_language/types:placements", @@ -402,7 +375,7 @@ index 4f378fa..58c7415 100644 ], ) -@@ -236,7 +223,5 @@ py_test( +@@ -235,7 +222,5 @@ py_test( "//federated_language/context_stack:context_stack_impl", "//federated_language/types:computation_types", "//federated_language/types:placements", @@ -411,10 +384,10 @@ index 4f378fa..58c7415 100644 ], ) diff --git federated_language/program/BUILD federated_language/program/BUILD -index 8defe61..f6e50a3 100644 +index 9d2715e..1bcc1a2 100644 --- federated_language/program/BUILD +++ federated_language/program/BUILD -@@ -72,8 +72,6 @@ py_test( +@@ -73,8 +73,6 @@ py_test( "//federated_language/context_stack:context_stack_impl", "//federated_language/types:computation_types", "//federated_language/types:placements", @@ -423,7 +396,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -83,7 +81,6 @@ py_library( +@@ -84,7 +82,6 @@ py_library( deps = [ ":release_manager", ":value_reference", @@ -431,7 +404,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -93,9 +90,6 @@ py_test( +@@ -94,9 +91,6 @@ py_test( deps = [ ":logging_release_manager", ":program_test_utils", @@ -441,7 +414,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -114,9 +108,6 @@ py_test( +@@ -115,9 +109,6 @@ py_test( deps = [ ":memory_release_manager", ":program_test_utils", @@ -451,7 +424,25 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -133,10 +124,7 @@ py_library( +@@ -134,7 +125,6 @@ py_library( + "//federated_language/types:computation_types", + "//federated_language/types:placements", + "//federated_language/types:type_conversions", +- "@pypi//dm_tree", + ], + ) + +@@ -151,9 +141,6 @@ py_test( + "//federated_language/federated_context:federated_computation", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -170,10 +157,7 @@ py_library( py_test( name = "program_state_manager_test", srcs = ["program_state_manager_test.py"], @@ -463,7 +454,7 @@ index 8defe61..f6e50a3 100644 ) py_library( -@@ -148,9 +136,6 @@ py_library( +@@ -185,9 +169,6 @@ py_library( "//federated_language/common_libs:py_typecheck", "//federated_language/common_libs:serializable", "//federated_language/types:computation_types", @@ -473,7 +464,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -161,8 +146,6 @@ py_library( +@@ -198,8 +179,6 @@ py_library( ":structure_utils", ":value_reference", "//federated_language/common_libs:py_typecheck", @@ -482,7 +473,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -172,9 +155,6 @@ py_test( +@@ -209,9 +188,6 @@ py_test( deps = [ ":program_test_utils", ":release_manager", @@ -492,7 +483,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -195,19 +175,13 @@ py_test( +@@ -232,19 +208,13 @@ py_test( ":program_test_utils", ":serialization_utils", "//federated_language/types:computation_types", @@ -513,7 +504,7 @@ index 8defe61..f6e50a3 100644 ) py_test( -@@ -216,9 +190,6 @@ py_test( +@@ -253,9 +223,6 @@ py_test( deps = [ ":program_test_utils", ":structure_utils", @@ -523,7 +514,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -229,7 +200,6 @@ py_library( +@@ -266,7 +233,6 @@ py_library( ":structure_utils", "//federated_language/types:computation_types", "//federated_language/types:typed_object", @@ -531,7 +522,7 @@ index 8defe61..f6e50a3 100644 ], ) -@@ -239,8 +209,5 @@ py_test( +@@ -276,8 +242,5 @@ py_test( deps = [ ":program_test_utils", ":value_reference", @@ -540,22 +531,11 @@ index 8defe61..f6e50a3 100644 - "@pypi//numpy", ], ) -diff --git federated_language/test/BUILD federated_language/test/BUILD -index 709665c..17c84a6 100644 ---- federated_language/test/BUILD -+++ federated_language/test/BUILD -@@ -47,6 +47,5 @@ py_test( - "//federated_language/federated_context:federated_computation", - "//federated_language/federated_context:intrinsics", - "//federated_language/types:placements", -- "@pypi//absl_py", - ], - ) diff --git federated_language/types/BUILD federated_language/types/BUILD -index a5bcb11..abe4c0b 100644 +index 5e985da..74e3b33 100644 --- federated_language/types/BUILD +++ federated_language/types/BUILD -@@ -53,7 +53,6 @@ py_test( +@@ -52,7 +52,6 @@ py_test( deps = [ ":array_shape", "//federated_language/proto:array_py_pb2", @@ -563,7 +543,7 @@ index a5bcb11..abe4c0b 100644 ], ) -@@ -68,9 +67,6 @@ py_library( +@@ -67,9 +66,6 @@ py_library( "//federated_language/common_libs:structure", "//federated_language/proto:array_py_pb2", "//federated_language/proto:computation_py_pb2", @@ -573,7 +553,7 @@ index a5bcb11..abe4c0b 100644 ], ) -@@ -98,31 +94,19 @@ py_test( +@@ -83,31 +79,19 @@ py_test( "//federated_language/common_libs:structure", "//federated_language/proto:computation_py_pb2", "//federated_language/proto:data_type_py_pb2", @@ -607,7 +587,7 @@ index a5bcb11..abe4c0b 100644 ) py_library( -@@ -134,10 +118,7 @@ py_test( +@@ -119,10 +103,7 @@ py_test( name = "placements_test", size = "small", srcs = ["placements_test.py"], @@ -619,26 +599,26 @@ index a5bcb11..abe4c0b 100644 ) py_library( -@@ -151,8 +132,6 @@ py_library( - ":type_transformations", +@@ -133,8 +114,6 @@ py_library( + ":computation_types", + ":placements", "//federated_language/common_libs:py_typecheck", - "//federated_language/common_libs:structure", - "@pypi//ml_dtypes", - "@pypi//numpy", ], ) -@@ -165,9 +144,6 @@ py_test( +@@ -146,9 +125,6 @@ py_test( + ":computation_types", ":placements", ":type_analysis", - "//federated_language/common_libs:structure", - "@pypi//absl_py", - "@pypi//ml_dtypes", - "@pypi//numpy", ], ) -@@ -180,9 +156,6 @@ py_library( +@@ -161,9 +137,6 @@ py_library( ":typed_object", "//federated_language/common_libs:py_typecheck", "//federated_language/common_libs:structure", @@ -648,7 +628,7 @@ index a5bcb11..abe4c0b 100644 ], ) -@@ -196,9 +169,6 @@ py_test( +@@ -177,9 +150,6 @@ py_test( ":type_conversions", ":typed_object", "//federated_language/common_libs:structure", @@ -658,7 +638,7 @@ index a5bcb11..abe4c0b 100644 ], ) -@@ -215,8 +185,6 @@ py_test( +@@ -196,8 +166,6 @@ py_test( deps = [ ":computation_types", ":type_factory", @@ -667,7 +647,7 @@ index a5bcb11..abe4c0b 100644 ], ) -@@ -243,8 +211,6 @@ py_test( +@@ -218,8 +186,6 @@ py_test( ":computation_types", ":placements", ":type_transformations",