diff --git a/.bazelrc b/.bazelrc index fd7cb10632..9417b611bf 100644 --- a/.bazelrc +++ b/.bazelrc @@ -27,11 +27,8 @@ build --compilation_mode=opt # build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true -# pybind_abseil does not include a __init__.py, which breaks dependencies when -# bazel produces an empty __init__.py and hides the symbols from the `.so` -# python extension. We must ask bazel not to generate these and rely on -# explicit __init__.py files. -# build --incompatible_default_to_explicit_init_py +# Do not automatically create `__init__.py` in the runfiles of Python targets. +build --incompatible_default_to_explicit_init_py # Haswell processor and later optimizations. This covers most processors deployed # today, includin Colab CPU runtimes. diff --git a/WORKSPACE b/WORKSPACE index 5359b946f0..a69c44bdc9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -87,6 +87,22 @@ http_archive( # }, # ) +http_archive( + name = "federated_language", + patches = [ + "//third_party/federated_language:proto_library_loads.patch", + "//third_party/federated_language:python_deps.patch", + # Must come after `python_deps.patch`, this patches the output of `python_deps.patch`. + "//third_party/federated_language:structure_visibility.patch", + ], + repo_mapping = { + "@protobuf": "@com_google_protobuf", + }, + sha256 = "de4bbfd93ee10c3797463d6a96911963d8969fb5ed0b264ae4f3f3088013fed4", + strip_prefix = "federated-language-5405fd4b2965e2f7c6c240b386f0540e4114818e", + url = "https://github.com/google-parfait/federated-language/archive/5405fd4b2965e2f7c6c240b386f0540e4114818e.tar.gz", +) + # The version of TensorFlow should match the version in # https://github.com/google-parfait/tensorflow-federated/blob/main/requirements.txt. http_archive( diff --git a/docs/deployment.md b/docs/deployment.md index bb3479ea48..47181251cb 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -14,7 +14,7 @@ There are two principal modes of deployment for TFF computations: * **Native backends**. We're going to refer to a backend as *native* if it is capable of interpreting the syntactic structure of TFF computations as defined in - [`computation.proto`](https://github.com/google-parfait/tensorflow-federated/blob/main/tensorflow_federated/proto/v0/computation.proto). + [`computation.proto`](https://github.com/google-parfait/federated-language/blob/main/tensorflow_federated/proto/computation.proto). A native backend does not necessarily have to support all language constructs or intrinsics. Native backends must implement one of the standard TFF *executor* interfaces, such as diff --git a/docs/design/compilation.md b/docs/design/compilation.md index 703d78768f..2168805453 100644 --- a/docs/design/compilation.md +++ b/docs/design/compilation.md @@ -33,13 +33,13 @@ support [Computations](#computation) backed by other external runtimes. ### `Computation` A -[pb.Computation](https://github.com/google-parfait/tensorflow-federated/blob/main/tensorflow_federated/proto/v0/computation.proto) +[pb.Computation](https://github.com/google-parfait/federated-language/blob/main/tensorflow_federated/proto/computation.proto) is the Proto or serialized representation of the [AST](#ast). #### TensorFlow Computation A -[pb.Computation](https://github.com/google-parfait/tensorflow-federated/blob/main/tensorflow_federated/proto/v0/computation.proto) +[pb.Computation](https://github.com/google-parfait/federated-language/blob/main/tensorflow_federated/proto/computation.proto) that represents a [Computations](#computation) that will be delegated to the [TensorFlow](execution.md#tensorflow) runtime. diff --git a/docs/federated_core.md b/docs/federated_core.md index 21e2eadc53..a8cd5c0a12 100644 --- a/docs/federated_core.md +++ b/docs/federated_core.md @@ -80,7 +80,7 @@ blocks such as `tff.federated_sum`, `tff.federated_reduce`, or TFF uses an internal language to represent federated computations, the syntax of which is defined by the serializable representation in -[computation.proto](https://github.com/google-parfait/tensorflow-federated/blob/main/tensorflow_federated/proto/v0/computation.proto). +[computation.proto](https://github.com/google-parfait/federated-language/blob/main/tensorflow_federated/proto/computation.proto). Users of FC API generally won't need to interact with this language directly, though. Rather, we provide a Python API (the `tff` namespace) that wraps arounds it as a way to define computations. diff --git a/examples/custom_data_backend/BUILD b/examples/custom_data_backend/BUILD index ce727262cd..d70272aeb3 100644 --- a/examples/custom_data_backend/BUILD +++ b/examples/custom_data_backend/BUILD @@ -6,12 +6,11 @@ cc_library( hdrs = ["data_backend_example.h"], deps = [ "//tensorflow_federated/cc/core/impl/executors:data_backend", - "//tensorflow_federated/cc/core/impl/executors:status_macros", "//tensorflow_federated/cc/core/impl/executors:tensorflow_utils", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:protos_all_cc", ], diff --git a/examples/custom_data_backend/data_backend_example.cc b/examples/custom_data_backend/data_backend_example.cc index c8e6a6a7bd..b2881923cd 100644 --- a/examples/custom_data_backend/data_backend_example.cc +++ b/examples/custom_data_backend/data_backend_example.cc @@ -19,17 +19,17 @@ limitations under the License #include #include "google/protobuf/any.pb.h" +#include "federated_language/proto/array.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h" -#include "tensorflow_federated/proto/v0/array.pb.h" namespace tensorflow_federated_examples { namespace { -using ::tensorflow_federated::v0::Data; -using ::tensorflow_federated::v0::Type; +using ::federated_language::Data; +using ::federated_language::Type; using ::tensorflow_federated::v0::Value; // Constant URIs and values resolved by `DataBackendExample`. diff --git a/examples/custom_data_backend/data_backend_example.h b/examples/custom_data_backend/data_backend_example.h index 3b4dacb23c..38f8dfd75e 100644 --- a/examples/custom_data_backend/data_backend_example.h +++ b/examples/custom_data_backend/data_backend_example.h @@ -17,8 +17,8 @@ limitations under the License #define THIRD_PARTY_TENSORFLOW_FEDERATED_EXAMPLES_CUSTOM_DATA_BACKEND_DATA_BACKEND_EXAMPLE_H_ #include "absl/status/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/data_backend.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated_examples { @@ -26,10 +26,9 @@ namespace tensorflow_federated_examples { // An example implementation of `DataBackend` used to show Python interop. class DataBackendExample : public tensorflow_federated::DataBackend { public: - absl::Status ResolveToValue( - const tensorflow_federated::v0::Data& data_reference, - const tensorflow_federated::v0::Type& data_type, - tensorflow_federated::v0::Value& value_out) final; + absl::Status ResolveToValue(const federated_language::Data& data_reference, + const federated_language::Type& data_type, + tensorflow_federated::v0::Value& value_out) final; }; } // namespace tensorflow_federated_examples diff --git a/tensorflow_federated/BUILD b/tensorflow_federated/BUILD index 5eba5ec834..ff6bc8e225 100644 --- a/tensorflow_federated/BUILD +++ b/tensorflow_federated/BUILD @@ -20,19 +20,13 @@ py_library( "//tensorflow_federated/python/core/environments/jax", "//tensorflow_federated/python/core/environments/tensorflow", "//tensorflow_federated/python/core/framework", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/federated_context:value_impl", "//tensorflow_federated/python/core/impl/types", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:typed_object", "//tensorflow_federated/python/core/templates", "//tensorflow_federated/python/core/test", "//tensorflow_federated/python/learning", "//tensorflow_federated/python/program", "//tensorflow_federated/python/simulation", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/__init__.py b/tensorflow_federated/__init__.py index 86127c5470..c723180760 100644 --- a/tensorflow_federated/__init__.py +++ b/tensorflow_federated/__init__.py @@ -29,38 +29,38 @@ from tensorflow_federated.python.core.environments import jax from tensorflow_federated.python.core.environments import tensorflow from tensorflow_federated.python.core.impl import types -from tensorflow_federated.python.core.impl.computation.computation_base import Computation -from tensorflow_federated.python.core.impl.federated_context.federated_computation import federated_computation -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_aggregate -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_broadcast -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_eval -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_map -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_max -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_mean -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_min -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_select -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_sum -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_secure_sum_bitwidth -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_select -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_sum -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_value -from tensorflow_federated.python.core.impl.federated_context.intrinsics import federated_zip -from tensorflow_federated.python.core.impl.federated_context.intrinsics import sequence_map -from tensorflow_federated.python.core.impl.federated_context.intrinsics import sequence_reduce -from tensorflow_federated.python.core.impl.federated_context.intrinsics import sequence_sum -from tensorflow_federated.python.core.impl.federated_context.value_impl import to_value -from tensorflow_federated.python.core.impl.federated_context.value_impl import Value -from tensorflow_federated.python.core.impl.types.computation_types import FederatedType -from tensorflow_federated.python.core.impl.types.computation_types import FunctionType -from tensorflow_federated.python.core.impl.types.computation_types import SequenceType -from tensorflow_federated.python.core.impl.types.computation_types import StructType -from tensorflow_federated.python.core.impl.types.computation_types import StructWithPythonType -from tensorflow_federated.python.core.impl.types.computation_types import TensorType -from tensorflow_federated.python.core.impl.types.computation_types import to_type -from tensorflow_federated.python.core.impl.types.computation_types import Type -from tensorflow_federated.python.core.impl.types.placements import CLIENTS -from tensorflow_federated.python.core.impl.types.placements import SERVER -from tensorflow_federated.python.core.impl.types.typed_object import TypedObject +from federated_language import Computation +from federated_language import federated_computation +from federated_language import federated_aggregate +from federated_language import federated_broadcast +from federated_language import federated_eval +from federated_language import federated_map +from federated_language import federated_max +from federated_language import federated_mean +from federated_language import federated_min +from federated_language import federated_secure_select +from federated_language import federated_secure_sum +from federated_language import federated_secure_sum_bitwidth +from federated_language import federated_select +from federated_language import federated_sum +from federated_language import federated_value +from federated_language import federated_zip +from federated_language import sequence_map +from federated_language import sequence_reduce +from federated_language import sequence_sum +from federated_language import to_value +from federated_language import Value +from federated_language import FederatedType +from federated_language import FunctionType +from federated_language import SequenceType +from federated_language import StructType +from federated_language import StructWithPythonType +from federated_language import TensorType +from federated_language import to_type +from federated_language import Type +from federated_language import CLIENTS +from federated_language import SERVER +from federated_language import TypedObject from tensorflow_federated.version import __version__ # pylint: enable=g-importing-member diff --git a/tensorflow_federated/cc/core/impl/executors/BUILD b/tensorflow_federated/cc/core/impl/executors/BUILD index 0ddb6c752b..5f5e698eaf 100644 --- a/tensorflow_federated/cc/core/impl/executors/BUILD +++ b/tensorflow_federated/cc/core/impl/executors/BUILD @@ -25,13 +25,13 @@ cc_library( name = "array_shape_test_utils", testonly = True, hdrs = ["array_shape_test_utils.h"], - deps = ["//tensorflow_federated/proto/v0:array_cc_proto"], + deps = ["@federated_language//federated_language/proto:array_cc_proto"], ) cc_library( name = "array_shape_utils", hdrs = ["array_shape_utils.h"], - deps = ["//tensorflow_federated/proto/v0:array_cc_proto"], + deps = ["@federated_language//federated_language/proto:array_cc_proto"], ) cc_library( @@ -39,12 +39,12 @@ cc_library( testonly = True, hdrs = ["array_test_utils.h"], deps = [ - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", ], ) @@ -73,7 +73,6 @@ cc_library( ":status_macros", ":threading", ":value_validation", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -83,6 +82,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core/platform:macros", ], ) @@ -101,11 +101,11 @@ cc_test( ":value_test_utils", "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/core:tensorflow", ], ) @@ -113,7 +113,7 @@ cc_test( cc_library( name = "computations", hdrs = ["computations.h"], - deps = ["//tensorflow_federated/proto/v0:computation_cc_proto"], + deps = ["@federated_language//federated_language/proto:computation_cc_proto"], ) cc_library( @@ -122,10 +122,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status_macros", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -155,9 +155,9 @@ cc_test( ":value_test_utils", "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", ], ) @@ -216,12 +216,12 @@ cc_library( ":dataset_from_tensor_structures", ":status_macros", ":tensorflow_utils", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:portable_gif_internal", "@org_tensorflow//tensorflow/core/data:standalone", @@ -253,7 +253,6 @@ cc_library( deps = [ ":status_macros", ":tensorflow_utils", - "//tensorflow_federated/proto/v0:computation_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -262,6 +261,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/c:tf_datatype", "@org_tensorflow//tensorflow/c:tf_status_headers", "@org_tensorflow//tensorflow/c/eager:c_api", @@ -311,7 +311,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status_macros", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -323,6 +322,7 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/utility", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core/profiler/lib:traceme", ], ) @@ -344,13 +344,13 @@ pybind_extension( ":streaming_remote_executor", ":tensorflow_executor", ":xla_executor", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core/platform:status", "@org_tensorflow//tensorflow/core/platform:strcat", @@ -369,7 +369,6 @@ cc_library( ":cardinalities", ":executor", ":status_conversion", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_grpc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_github_grpc_grpc//:grpc++", @@ -381,6 +380,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -397,12 +397,12 @@ cc_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -455,7 +455,6 @@ cc_library( ":tensor_serialization", ":threading", ":value_validation", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -465,6 +464,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core/platform:macros", @@ -533,10 +533,10 @@ cc_library( deps = [ ":data_backend", "//tensorflow_federated/cc/testing:protobuf_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -579,12 +579,12 @@ cc_library( deps = [ ":executor", ":status_macros", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:tensorflow", ], ) @@ -602,8 +602,6 @@ cc_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -611,6 +609,8 @@ cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/cc:array_ops", "@org_tensorflow//tensorflow/cc:math_ops", "@org_tensorflow//tensorflow/cc:scope", @@ -631,7 +631,6 @@ cc_library( ":status_conversion", ":status_macros", ":threading", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_grpc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_github_grpc_grpc//:grpc++", @@ -640,6 +639,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -681,7 +681,6 @@ cc_library( ":struct_traversal_order", ":tensor_serialization", ":threading", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", @@ -689,6 +688,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core/data:standalone", "@org_tensorflow//tensorflow/core/platform:status", @@ -712,8 +712,8 @@ cc_test( ":value_test_utils", "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", "@com_google_absl//absl/status", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:tensorflow", ], ) @@ -795,7 +795,6 @@ cc_library( ":status_macros", ":threading", ":type_utils", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_grpc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_github_grpc_grpc//:grpc++", @@ -805,6 +804,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -823,8 +823,6 @@ cc_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_grpc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_github_grpc_grpc//:grpc++", @@ -837,6 +835,8 @@ cc_test( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/cc:array_ops", "@org_tensorflow//tensorflow/cc:math_ops", ], @@ -846,10 +846,10 @@ cc_library( name = "struct_traversal_order", hdrs = ["struct_traversal_order.h"], deps = [ - "//tensorflow_federated/proto/v0:computation_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -874,12 +874,12 @@ cc_library( ":array_shape_utils", ":status_macros", ":tensorflow_utils", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:protos_all_cc", ], @@ -899,7 +899,6 @@ cc_library( ":tensor_serialization", ":tensorflow_utils", ":threading", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -910,6 +909,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core/common_runtime:core", @@ -939,15 +939,15 @@ tff_cc_cpu_gpu_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/c:tf_status_headers", "@org_tensorflow//tensorflow/c/eager:c_api", "@org_tensorflow//tensorflow/c/eager:tfe_context_internal", @@ -969,8 +969,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status_macros", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -978,6 +976,8 @@ cc_library( "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:portable_gif_internal", "@org_tensorflow//tensorflow/core:protos_all_cc", @@ -994,12 +994,12 @@ cc_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@eigen_archive//:eigen3", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/core:framework", "@org_tensorflow//tensorflow/core:portable_gif_internal", "@org_tensorflow//tensorflow/core:protos_all_cc", @@ -1044,9 +1044,9 @@ cc_library( testonly = True, hdrs = ["type_test_utils.h"], deps = [ - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", ], ) @@ -1056,13 +1056,13 @@ cc_library( hdrs = ["type_utils.h"], deps = [ ":status_macros", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], ) @@ -1078,15 +1078,15 @@ cc_library( ":status_macros", ":tensor_serialization", "//tensorflow_federated/cc/testing:protobuf_matchers", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/cc:dataset_ops_internal", "@org_tensorflow//tensorflow/core:core_cpu_base", "@org_tensorflow//tensorflow/core:framework", @@ -1102,11 +1102,11 @@ cc_library( deps = [ ":cardinalities", ":federated_intrinsics", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -1119,10 +1119,10 @@ cc_test( ":value_validation", "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:computation_cc_proto", "//tensorflow_federated/proto/v0:executor_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", ], ) @@ -1136,12 +1136,12 @@ cc_library( ":tensor_serialization", ":threading", ":xla_utils", - "//tensorflow_federated/proto/v0:computation_cc_proto", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@federated_language//federated_language/proto:computation_cc_proto", "@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit", # buildcleaner: keep # Linking in this dependency ensures that XLA can compile its code for the CPU host. "@org_tensorflow//tensorflow/compiler/tf2xla:common", "@org_tensorflow//tensorflow/compiler/xla:literal", @@ -1176,14 +1176,14 @@ tff_cc_cpu_gpu_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/compiler/tf2xla:common", "@org_tensorflow//tensorflow/compiler/xla:shape_util", "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1213,14 +1213,14 @@ cc_test( "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:protobuf_matchers", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_jit", # buildcleaner: keep # Linking in this dependency ensures that XLA can compile its code for the GPU. "@org_tensorflow//tensorflow/compiler/tf2xla:common", "@org_tensorflow//tensorflow/compiler/xla:shape_util", @@ -1242,12 +1242,12 @@ cc_library( hdrs = ["xla_utils.h"], deps = [ ":status_macros", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/compiler/xla:literal", "@org_tensorflow//tensorflow/compiler/xla:shape_util", "@org_tensorflow//tensorflow/compiler/xla:types", @@ -1264,12 +1264,12 @@ cc_test( ":xla_utils", "//tensorflow_federated/cc/testing:oss_test_main", "//tensorflow_federated/cc/testing:status_matchers", - "//tensorflow_federated/proto/v0:array_cc_proto", - "//tensorflow_federated/proto/v0:computation_cc_proto", - "//tensorflow_federated/proto/v0:data_type_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@eigen_archive//:eigen3", + "@federated_language//federated_language/proto:array_cc_proto", + "@federated_language//federated_language/proto:computation_cc_proto", + "@federated_language//federated_language/proto:data_type_cc_proto", "@org_tensorflow//tensorflow/compiler/xla:literal", "@org_tensorflow//tensorflow/compiler/xla:literal_util", "@org_tensorflow//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow_federated/cc/core/impl/executors/array_shape_test_utils.h b/tensorflow_federated/cc/core/impl/executors/array_shape_test_utils.h index cde7b789b1..b2256aaf6f 100644 --- a/tensorflow_federated/cc/core/impl/executors/array_shape_test_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/array_shape_test_utils.h @@ -19,20 +19,21 @@ limitations under the License #include #include -#include "tensorflow_federated/proto/v0/array.pb.h" +#include "federated_language/proto/array.pb.h" namespace tensorflow_federated { namespace testing { -inline v0::ArrayShape CreateArrayShape(std::initializer_list dims, - bool unknown_rank) { - v0::ArrayShape shape_pb; +inline federated_language::ArrayShape CreateArrayShape( + std::initializer_list dims, bool unknown_rank) { + federated_language::ArrayShape shape_pb; shape_pb.mutable_dim()->Assign(dims.begin(), dims.end()); shape_pb.set_unknown_rank(unknown_rank); return shape_pb; } -inline v0::ArrayShape CreateArrayShape(std::initializer_list dims) { +inline federated_language::ArrayShape CreateArrayShape( + std::initializer_list dims) { return CreateArrayShape(dims, false); } diff --git a/tensorflow_federated/cc/core/impl/executors/array_shape_utils.h b/tensorflow_federated/cc/core/impl/executors/array_shape_utils.h index 0e1599e07b..6b2f9367b1 100644 --- a/tensorflow_federated/cc/core/impl/executors/array_shape_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/array_shape_utils.h @@ -16,11 +16,11 @@ limitations under the License #ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_ARRAY_SHAPE_UTILS_H_ #define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_ARRAY_SHAPE_UTILS_H_ -#include "tensorflow_federated/proto/v0/array.pb.h" +#include "federated_language/proto/array.pb.h" namespace tensorflow_federated { -inline bool IsScalar(const v0::ArrayShape& shape) { +inline bool IsScalar(const federated_language::ArrayShape& shape) { return shape.dim().empty() && !shape.unknown_rank(); } diff --git a/tensorflow_federated/cc/core/impl/executors/array_test_utils.h b/tensorflow_federated/cc/core/impl/executors/array_test_utils.h index f368fcb027..062ea168f3 100644 --- a/tensorflow_federated/cc/core/impl/executors/array_test_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/array_test_utils.h @@ -26,71 +26,71 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "third_party/eigen3/Eigen/Core" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/data_type.pb.h" namespace tensorflow_federated { namespace testing { template -inline absl::StatusOr CreateArray(v0::DataType dtype, - v0::ArrayShape shape_pb, - std::initializer_list values) { - v0::Array array_pb; +inline absl::StatusOr CreateArray( + federated_language::DataType dtype, federated_language::ArrayShape shape_pb, + std::initializer_list values) { + federated_language::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); switch (dtype) { - case v0::DataType::DT_BOOL: { + case federated_language::DataType::DT_BOOL: { array_pb.mutable_bool_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_INT8: { + case federated_language::DataType::DT_INT8: { array_pb.mutable_int8_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_INT16: { + case federated_language::DataType::DT_INT16: { array_pb.mutable_int16_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_INT32: { + case federated_language::DataType::DT_INT32: { array_pb.mutable_int32_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_INT64: { + case federated_language::DataType::DT_INT64: { array_pb.mutable_int64_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_UINT8: { + case federated_language::DataType::DT_UINT8: { array_pb.mutable_uint8_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_UINT16: { + case federated_language::DataType::DT_UINT16: { array_pb.mutable_uint16_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_UINT32: { + case federated_language::DataType::DT_UINT32: { array_pb.mutable_uint32_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_UINT64: { + case federated_language::DataType::DT_UINT64: { array_pb.mutable_uint64_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_FLOAT: { + case federated_language::DataType::DT_FLOAT: { array_pb.mutable_float32_list()->mutable_value()->Assign(values.begin(), values.end()); break; } - case v0::DataType::DT_DOUBLE: { + case federated_language::DataType::DT_DOUBLE: { array_pb.mutable_float64_list()->mutable_value()->Assign(values.begin(), values.end()); break; @@ -103,14 +103,14 @@ inline absl::StatusOr CreateArray(v0::DataType dtype, } // Overload for Eigen::half. -inline absl::StatusOr CreateArray( - v0::DataType dtype, v0::ArrayShape shape_pb, +inline absl::StatusOr CreateArray( + federated_language::DataType dtype, federated_language::ArrayShape shape_pb, std::initializer_list values) { - v0::Array array_pb; + federated_language::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); switch (dtype) { - case v0::DataType::DT_HALF: { + case federated_language::DataType::DT_HALF: { auto size = values.size(); array_pb.mutable_float16_list()->mutable_value()->Reserve(size); for (auto element : values) { @@ -128,20 +128,20 @@ inline absl::StatusOr CreateArray( // Overload for complex. template -inline absl::StatusOr CreateArray( - v0::DataType dtype, v0::ArrayShape shape_pb, +inline absl::StatusOr CreateArray( + federated_language::DataType dtype, federated_language::ArrayShape shape_pb, std::initializer_list> values) { - v0::Array array_pb; + federated_language::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); const T* begin = reinterpret_cast(values.begin()); switch (dtype) { - case v0::DataType::DT_COMPLEX64: { + case federated_language::DataType::DT_COMPLEX64: { array_pb.mutable_complex64_list()->mutable_value()->Assign( begin, begin + values.size() * 2); break; } - case v0::DataType::DT_COMPLEX128: { + case federated_language::DataType::DT_COMPLEX128: { array_pb.mutable_complex128_list()->mutable_value()->Assign( begin, begin + values.size() * 2); break; @@ -154,14 +154,14 @@ inline absl::StatusOr CreateArray( } // Overload for Eigen::bfloat16. -inline absl::StatusOr CreateArray( - v0::DataType dtype, v0::ArrayShape shape_pb, +inline absl::StatusOr CreateArray( + federated_language::DataType dtype, federated_language::ArrayShape shape_pb, std::initializer_list values) { - v0::Array array_pb; + federated_language::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); switch (dtype) { - case v0::DataType::DT_BFLOAT16: { + case federated_language::DataType::DT_BFLOAT16: { auto size = values.size(); array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size); for (auto element : values) { @@ -178,14 +178,14 @@ inline absl::StatusOr CreateArray( } // Overload for string. -inline absl::StatusOr CreateArray( - v0::DataType dtype, v0::ArrayShape shape_pb, +inline absl::StatusOr CreateArray( + federated_language::DataType dtype, federated_language::ArrayShape shape_pb, std::initializer_list values) { - v0::Array array_pb; + federated_language::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); switch (dtype) { - case v0::DT_STRING: { + case federated_language::DT_STRING: { array_pb.mutable_string_list()->mutable_value()->Assign(values.begin(), values.end()); break; @@ -197,10 +197,10 @@ inline absl::StatusOr CreateArray( return array_pb; } -inline absl::StatusOr CreateArrayContent(v0::DataType dtype, - v0::ArrayShape shape_pb, - std::string_view content) { - v0::Array array_pb; +inline absl::StatusOr CreateArrayContent( + federated_language::DataType dtype, federated_language::ArrayShape shape_pb, + std::string_view content) { + federated_language::Array array_pb; array_pb.set_dtype(dtype); array_pb.mutable_shape()->Swap(&shape_pb); *array_pb.mutable_content() = content; diff --git a/tensorflow_federated/cc/core/impl/executors/composing_executor.cc b/tensorflow_federated/cc/core/impl/executors/composing_executor.cc index 7bed098132..d6d5f976d5 100644 --- a/tensorflow_federated/cc/core/impl/executors/composing_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/composing_executor.cc @@ -37,6 +37,7 @@ limitations under the License #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/computations.h" @@ -45,7 +46,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" #include "tensorflow_federated/cc/core/impl/executors/value_validation.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -182,7 +182,7 @@ struct TypedFederatedIntrinsic { FederatedIntrinsic federated_intrinsic; // The type signature of the FederatedIntrinsic. - v0::FunctionType type_signature; + federated_language::FunctionType type_signature; }; using ValueVariant = std::variant; @@ -589,7 +589,7 @@ class ComposingExecutor : public ExecutorBase { } absl::StatusOr CallIntrinsicEvalAtClients( - ExecutorValue&& arg, const v0::FunctionType& type_pb) { + ExecutorValue&& arg, const federated_language::FunctionType& type_pb) { auto traceme = Trace("CallIntrinsicEvalAtClients"); auto fn_to_eval = TFF_TRY(arg.GetUnplacedFunctionProto("federated_eval_at_clients_fn")); @@ -612,7 +612,7 @@ class ComposingExecutor : public ExecutorBase { } absl::StatusOr CallIntrinsicAggregate( - ExecutorValue&& arg, const v0::FunctionType& type_pb) { + ExecutorValue&& arg, const federated_language::FunctionType& type_pb) { auto traceme = Trace("CallIntrinsicAggregate"); TFF_TRY(arg.CheckLenForUseAsArgument("federated_aggregate", 5)); const auto& value = arg.structure()->at(0); @@ -723,7 +723,7 @@ class ComposingExecutor : public ExecutorBase { } absl::StatusOr CallIntrinsicMap( - ExecutorValue&& arg, const v0::FunctionType& type_pb) { + ExecutorValue&& arg, const federated_language::FunctionType& type_pb) { auto traceme = Trace("CallIntrinsicMap"); TFF_TRY(arg.CheckLenForUseAsArgument("federated_map", 2)); const auto& fn = arg.structure()->at(0); @@ -761,7 +761,7 @@ class ComposingExecutor : public ExecutorBase { } absl::StatusOr CallIntrinsicSelect_( - ExecutorValue&& arg, const v0::FunctionType& type_pb) { + ExecutorValue&& arg, const federated_language::FunctionType& type_pb) { auto traceme = Trace("CallIntrinsicSelect_"); TFF_TRY(arg.CheckLenForUseAsArgument("federated_select", 4)); const ExecutorValue& keys = arg.structure()->at(0); @@ -847,7 +847,7 @@ class ComposingExecutor : public ExecutorBase { } absl::StatusOr CallIntrinsicZipAtClients( - ExecutorValue&& arg, const v0::FunctionType& type_pb) { + ExecutorValue&& arg, const federated_language::FunctionType& type_pb) { auto traceme = Trace("CallIntrinsicZipAtClients"); v0::Value zip_at_clients; zip_at_clients.mutable_computation() @@ -907,7 +907,8 @@ class ComposingExecutor : public ExecutorBase { absl::StatusOr CallFederatedIntrinsic( TypedFederatedIntrinsic typed_intrinsic, ExecutorValue arg) { FederatedIntrinsic function = typed_intrinsic.federated_intrinsic; - const v0::FunctionType& type_pb = typed_intrinsic.type_signature; + const federated_language::FunctionType& type_pb = + typed_intrinsic.type_signature; switch (function) { case FederatedIntrinsic::VALUE_AT_SERVER: { return CallIntrinsicValueAtServer(std::move(arg)); @@ -1005,7 +1006,8 @@ class ComposingExecutor : public ExecutorBase { for (int32_t i = 0; i < total_clients_; i++) { *values_pb->Add() = null_value; } - v0::FederatedType* type_pb = federated_pb->mutable_type(); + federated_language::FederatedType* type_pb = + federated_pb->mutable_type(); // All-equal-ness is not stored, so must be assumed to be false. // If the Python type system expects the value to be all-equal, it can // simply extract the first element in the list. @@ -1036,7 +1038,8 @@ class ComposingExecutor : public ExecutorBase { } case ExecutorValue::ValueType::SERVER: { v0::Value_Federated* federated_pb = value_pb->mutable_federated(); - v0::FederatedType* type_pb = federated_pb->mutable_type(); + federated_language::FederatedType* type_pb = + federated_pb->mutable_type(); // Server placement is assumed to be of cardinality one, and so must // be all-equal. type_pb->set_all_equal(true); diff --git a/tensorflow_federated/cc/core/impl/executors/composing_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/composing_executor_test.cc index 460388fa91..508f859f57 100644 --- a/tensorflow_federated/cc/core/impl/executors/composing_executor_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/composing_executor_test.cc @@ -26,6 +26,8 @@ limitations under the License #include "googletest/include/gtest/gtest.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/computations.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" @@ -34,8 +36,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/type_utils.h" #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -339,21 +339,21 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedAggregate) { v0::Value result_from_child = ServerV(TensorV("result from child")); v0::Value final_result_unfed = TensorV("final result"); - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_federated(); parameter_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); parameter_type_pb->set_all_equal(false); parameter_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_STRING); - v0::FederatedType* result_type_pb = + federated_language::DataType::DT_STRING); + federated_language::FederatedType* result_type_pb = intrinsic_type_pb.mutable_result()->mutable_federated(); result_type_pb->mutable_placement()->mutable_value()->set_uri( kServerUri.data(), kServerUri.size()); result_type_pb->set_all_equal(true); result_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_STRING); + federated_language::DataType::DT_STRING); for (const auto& child : mock_children_) { auto child_value = child->ExpectCreateValue(value); @@ -462,14 +462,14 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedMapAtClients) { std::vector client_vals_out; v0::Value fn = TensorV(24601); - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_federated(); parameter_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); parameter_type_pb->set_all_equal(false); parameter_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); + federated_language::DataType::DT_INT32); *intrinsic_type_pb.mutable_result()->mutable_federated() = *parameter_type_pb; for (uint32_t i = 0; i < mock_children_.size(); i++) { @@ -508,14 +508,14 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedMapAllEqualAtClients) { std::vector client_vals_out; v0::Value fn = TensorV(24601); - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_federated(); parameter_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); parameter_type_pb->set_all_equal(true); parameter_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); + federated_language::DataType::DT_INT32); *intrinsic_type_pb.mutable_result()->mutable_federated() = *parameter_type_pb; for (uint32_t i = 0; i < mock_children_.size(); i++) { @@ -593,14 +593,14 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedEvalAtClients) { v0::Value fn = TensorV(22); std::vector client_results; - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* result_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* result_type_pb = intrinsic_type_pb.mutable_result()->mutable_federated(); result_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); result_type_pb->set_all_equal(false); result_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); + federated_language::DataType::DT_INT32); for (uint32_t i = 0; i < mock_children_.size(); i++) { const auto& child = mock_children_[i]; @@ -647,8 +647,8 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedSelect) { v0::Value select_fn = TensorV("select_fn"); mock_server_->ExpectCreateMaterialize(TensorV("server_val")); - v0::FunctionType intrinsic_type_pb; - v0::StructType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::StructType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_struct_(); *parameter_type_pb->add_element()->mutable_value() = TFF_ASSERT_OK(InferTypeFromValue(keys)); @@ -708,21 +708,21 @@ TEST_F(ComposingExecutorTest, unplaced_id, // select_fn })); - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_federated(); parameter_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); parameter_type_pb->set_all_equal(false); parameter_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); - v0::FederatedType* result_type_pb = + federated_language::DataType::DT_INT32); + federated_language::FederatedType* result_type_pb = intrinsic_type_pb.mutable_result()->mutable_federated(); result_type_pb->mutable_placement()->mutable_value()->set_uri( kServerUri.data(), kServerUri.size()); result_type_pb->set_all_equal(true); result_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); + federated_language::DataType::DT_INT32); auto agg = TFF_ASSERT_OK( test_executor_->CreateValue(FederatedAggregateV(intrinsic_type_pb))); auto res = TFF_ASSERT_OK(test_executor_->CreateCall(agg, arg)); @@ -791,20 +791,20 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClientsFlat) { auto merged_struct = StructV({TensorV(1), TensorV(2)}); v0::Value merged = ClientsV({merged_struct}, true); - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_federated(); - v0::StructType* parameter_struct_type_pb = + federated_language::StructType* parameter_struct_type_pb = parameter_type_pb->mutable_member()->mutable_struct_(); - v0::FederatedType* result_type_pb = + federated_language::FederatedType* result_type_pb = intrinsic_type_pb.mutable_result()->mutable_federated(); result_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); result_type_pb->set_all_equal(false); - v0::StructType* result_struct_type_pb = + federated_language::StructType* result_struct_type_pb = result_type_pb->mutable_member()->mutable_struct_(); for (int i = 0; i < 2; ++i) { - v0::FederatedType* parameter_element_type_pb = + federated_language::FederatedType* parameter_element_type_pb = parameter_struct_type_pb->add_element() ->mutable_value() ->mutable_federated(); @@ -812,11 +812,11 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClientsFlat) { kClientsUri.data(), kClientsUri.size()); parameter_element_type_pb->set_all_equal(false); parameter_element_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); + federated_language::DataType::DT_INT32); result_struct_type_pb->add_element() ->mutable_value() ->mutable_tensor() - ->set_dtype(v0::DataType::DT_INT32); + ->set_dtype(federated_language::DataType::DT_INT32); } for (const auto& child : mock_children_) { @@ -848,20 +848,20 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClientsNested) { auto merged_struct = StructV({TensorV(1), StructV({TensorV(2)})}); v0::Value merged = ClientsV({merged_struct}, true); - v0::FunctionType intrinsic_type_pb; - v0::FederatedType* parameter_type_pb = + federated_language::FunctionType intrinsic_type_pb; + federated_language::FederatedType* parameter_type_pb = intrinsic_type_pb.mutable_parameter()->mutable_federated(); - v0::StructType* parameter_struct_type_pb = + federated_language::StructType* parameter_struct_type_pb = parameter_type_pb->mutable_member()->mutable_struct_(); - v0::FederatedType* result_type_pb = + federated_language::FederatedType* result_type_pb = intrinsic_type_pb.mutable_result()->mutable_federated(); result_type_pb->mutable_placement()->mutable_value()->set_uri( kClientsUri.data(), kClientsUri.size()); result_type_pb->set_all_equal(false); - v0::StructType* result_struct_type_pb = + federated_language::StructType* result_struct_type_pb = result_type_pb->mutable_member()->mutable_struct_(); for (int i = 0; i < 2; ++i) { - v0::FederatedType* parameter_element_type_pb = + federated_language::FederatedType* parameter_element_type_pb = parameter_struct_type_pb->add_element() ->mutable_value() ->mutable_federated(); @@ -869,11 +869,11 @@ TEST_F(ComposingExecutorTest, CreateCallFederatedZipAtClientsNested) { kClientsUri.data(), kClientsUri.size()); parameter_element_type_pb->set_all_equal(false); parameter_element_type_pb->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_INT32); + federated_language::DataType::DT_INT32); result_struct_type_pb->add_element() ->mutable_value() ->mutable_tensor() - ->set_dtype(v0::DataType::DT_INT32); + ->set_dtype(federated_language::DataType::DT_INT32); } for (const auto& child : mock_children_) { diff --git a/tensorflow_federated/cc/core/impl/executors/computations.h b/tensorflow_federated/cc/core/impl/executors/computations.h index cc93415fa4..521357cc19 100644 --- a/tensorflow_federated/cc/core/impl/executors/computations.h +++ b/tensorflow_federated/cc/core/impl/executors/computations.h @@ -16,13 +16,13 @@ limitations under the License #ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_COMPUTATIONS_H_ #define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_COMPUTATIONS_H_ -#include "tensorflow_federated/proto/v0/computation.pb.h" +#include "federated_language/proto/computation.pb.h" namespace tensorflow_federated { -inline v0::Computation IdentityComp() { - v0::Computation comp; - v0::Lambda* lambda = comp.mutable_lambda(); +inline federated_language::Computation IdentityComp() { + federated_language::Computation comp; + federated_language::Lambda* lambda = comp.mutable_lambda(); *lambda->mutable_parameter_name() = "x"; *lambda->mutable_result()->mutable_reference()->mutable_name() = "x"; return comp; diff --git a/tensorflow_federated/cc/core/impl/executors/data_backend.h b/tensorflow_federated/cc/core/impl/executors/data_backend.h index 126a14eef7..c676eae99f 100644 --- a/tensorflow_federated/cc/core/impl/executors/data_backend.h +++ b/tensorflow_federated/cc/core/impl/executors/data_backend.h @@ -18,8 +18,8 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -28,19 +28,20 @@ namespace tensorflow_federated { // the `DataExecutor`. class DataBackend { public: - // Resolves a `tensorflow_federated::v0::Data` object to a concrete + // Resolves a `federated_language::Data` object to a concrete // `tensorflow_federated::v0::Value` proto, writing the result to `value_out`. // // This function must be safe to call concurrently from multiple threads. - virtual absl::Status ResolveToValue(const v0::Data& data_reference, - const v0::Type& data_type, - v0::Value& value_out) = 0; + virtual absl::Status ResolveToValue( + const federated_language::Data& data_reference, + const federated_language::Type& data_type, v0::Value& value_out) = 0; - // Resolves a `tensorflow_federated::v0::Data` object to a concrete + // Resolves a `federated_language::Data` object to a concrete // `tensorflow_federated::v0::Value` proto, returning the result as a new // proto object. - absl::StatusOr ResolveToValue(const v0::Data& data_reference, - const v0::Type& data_type) { + absl::StatusOr ResolveToValue( + const federated_language::Data& data_reference, + const federated_language::Type& data_type) { v0::Value out; TFF_TRY(ResolveToValue(data_reference, data_type, out)); return out; diff --git a/tensorflow_federated/cc/core/impl/executors/data_executor.cc b/tensorflow_federated/cc/core/impl/executors/data_executor.cc index 4648f4742f..5623f26cbd 100644 --- a/tensorflow_federated/cc/core/impl/executors/data_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/data_executor.cc @@ -55,8 +55,8 @@ class DataExecutor : public ExecutorBase { // Note: `value_pb` is copied here in order to ensure that it remains // available for the lifetime of the resolving thread. However, it should // be relatively small and inexpensive (currently just a URI). - v0::Data data = value_pb.computation().data(); - v0::Type data_type = value_pb.computation().type(); + federated_language::Data data = value_pb.computation().data(); + federated_language::Type data_type = value_pb.computation().type(); return ThreadRun([this, data = std::move(data), data_type = std::move(data_type), this_keepalive = diff --git a/tensorflow_federated/cc/core/impl/executors/data_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/data_executor_test.cc index 194b38964d..938dba38fd 100644 --- a/tensorflow_federated/cc/core/impl/executors/data_executor_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/data_executor_test.cc @@ -22,14 +22,14 @@ limitations under the License #include "googlemock/include/gmock/gmock.h" #include "googletest/include/gtest/gtest.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/executor_test_base.h" #include "tensorflow_federated/cc/core/impl/executors/mock_data_backend.h" #include "tensorflow_federated/cc/core/impl/executors/mock_executor.h" #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -56,15 +56,15 @@ class DataExecutorTest : public ExecutorTestBase { TEST_F(DataExecutorTest, CreateValueResolvesData) { std::string uri = "some_data_uri"; - v0::Type data_type; - v0::TensorType* tensor_type = data_type.mutable_tensor(); - tensor_type->set_dtype(v0::DataType::DT_INT32); + federated_language::Type data_type; + federated_language::TensorType* tensor_type = data_type.mutable_tensor(); + tensor_type->set_dtype(federated_language::DataType::DT_INT32); tensor_type->mutable_dims()->Add(1); tensor_type->set_unknown_rank(false); v0::Value resolved_data_value = TensorV(22); mock_data_backend_->ExpectResolveToValue(uri, data_type, resolved_data_value); v0::Value unresolved_data_value; - v0::Computation* unresolved_data_computation = + federated_language::Computation* unresolved_data_computation = unresolved_data_value.mutable_computation(); unresolved_data_computation->mutable_data()->set_uri(uri); *unresolved_data_computation->mutable_type() = data_type; diff --git a/tensorflow_federated/cc/core/impl/executors/dataset_utils.cc b/tensorflow_federated/cc/core/impl/executors/dataset_utils.cc index 94ed35e842..a93d98a13d 100644 --- a/tensorflow_federated/cc/core/impl/executors/dataset_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/dataset_utils.cc @@ -22,6 +22,8 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/tstring.h" @@ -29,8 +31,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/dataset_from_tensor_structures.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" namespace tensorflow_federated { @@ -39,11 +39,11 @@ absl::StatusOr GraphDefTensorFromSequence( std::vector> tensor_structures; for (const v0::Value::Sequence::Element& element_pb : sequence_pb.element()) { std::vector tensors; - for (const v0::Array& array_pb : element_pb.flat_value()) { + for (const federated_language::Array& array_pb : element_pb.flat_value()) { // Repeated fields are used for strings and scalars to maintain // compatibility with TensorFlow. if (tensorflow_federated::IsScalar(array_pb.shape()) || - array_pb.dtype() == v0::DataType::DT_STRING) { + array_pb.dtype() == federated_language::DataType::DT_STRING) { tensors.push_back(TFF_TRY(TensorFromArray(array_pb))); } else { tensors.push_back(TFF_TRY(TensorFromArrayContent(array_pb))); diff --git a/tensorflow_federated/cc/core/impl/executors/dataset_utils.h b/tensorflow_federated/cc/core/impl/executors/dataset_utils.h index 2323dead8b..bcc26edd8b 100644 --- a/tensorflow_federated/cc/core/impl/executors/dataset_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/dataset_utils.h @@ -19,9 +19,9 @@ limitations under the License #include #include "absl/status/statusor.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { diff --git a/tensorflow_federated/cc/core/impl/executors/eager_computation.cc b/tensorflow_federated/cc/core/impl/executors/eager_computation.cc index afb89ff00e..dc25a645a3 100644 --- a/tensorflow_federated/cc/core/impl/executors/eager_computation.cc +++ b/tensorflow_federated/cc/core/impl/executors/eager_computation.cc @@ -207,10 +207,10 @@ void AddControlEdgeForInitOp(tensorflow::Graph* graph, } absl::Status PopulateBindingNames( - const v0::TensorFlow::Binding& binding, + const federated_language::TensorFlow::Binding& binding, std::vector& tensor_names_from_binding) { switch (binding.binding_case()) { - case v0::TensorFlow::Binding::kTensor: { + case federated_language::TensorFlow::Binding::kTensor: { if (!binding.tensor().has_tensor_name()) { return absl::InternalError("Tensor binding does not have a name."); } @@ -219,13 +219,13 @@ absl::Status PopulateBindingNames( GetNodeName(binding.tensor().tensor_name())); break; } - case v0::TensorFlow::Binding::kStruct: { + case federated_language::TensorFlow::Binding::kStruct: { for (const auto& b : binding.struct_().element()) { TFF_TRY(PopulateBindingNames(b, tensor_names_from_binding)); } break; } - case v0::TensorFlow::Binding::kSequence: { + case federated_language::TensorFlow::Binding::kSequence: { return absl::UnimplementedError( "Only Struct and Tensor Binding support added"); } @@ -238,8 +238,8 @@ absl::Status PopulateBindingNames( absl::StatusOr ConvertToFunctionDef( std::string init_op, const tensorflow::GraphDef& graphdef_pb, - const v0::TensorFlow::Binding& input_binding, - const v0::TensorFlow::Binding& output_binding) { + const federated_language::TensorFlow::Binding& input_binding, + const federated_language::TensorFlow::Binding& output_binding) { tensorflow::FunctionDef func_def; std::vector visited(graphdef_pb.node_size()); std::deque queue; @@ -329,7 +329,7 @@ void UpdateVarHandleOpNodesAsAnonymous(tensorflow::FunctionDef& func_def) { } // namespace absl::StatusOr EagerComputation::FromProto( - const v0::TensorFlow& comp_pb) { + const federated_language::TensorFlow& comp_pb) { if (!(comp_pb.graph_def().Is())) { return absl::InvalidArgumentError(absl::StrCat( "Unsupported type in Graph def proto: ", comp_pb.graph_def().type_url(), diff --git a/tensorflow_federated/cc/core/impl/executors/eager_computation.h b/tensorflow_federated/cc/core/impl/executors/eager_computation.h index 34e7afe016..2585a2ca5e 100644 --- a/tensorflow_federated/cc/core/impl/executors/eager_computation.h +++ b/tensorflow_federated/cc/core/impl/executors/eager_computation.h @@ -24,10 +24,10 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/framework/function.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" namespace tensorflow_federated { // Class responsible for converting a GraphDef into FunctionDef and executing it @@ -45,7 +45,7 @@ class EagerComputation { // If non-empty Layout map is passed, a Relayout op is inserted after each // VarHandleOp node which has sharding spec specified. static absl::StatusOr FromProto( - const v0::TensorFlow& comp_pb); + const federated_language::TensorFlow& comp_pb); EagerComputation( tensorflow::FunctionDef main_function_def, diff --git a/tensorflow_federated/cc/core/impl/executors/eager_computation_test.cc b/tensorflow_federated/cc/core/impl/executors/eager_computation_test.cc index 34630e7f73..09c3d05feb 100644 --- a/tensorflow_federated/cc/core/impl/executors/eager_computation_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/eager_computation_test.cc @@ -57,16 +57,16 @@ namespace { // TODO: b/256948367 - Move these common methods to a base test utility file. template -inline v0::TensorFlow::Binding TensorB(const TfOp& op) { +inline federated_language::TensorFlow::Binding TensorB(const TfOp& op) { const tensorflow::Node* node = op.node(); - v0::TensorFlow::Binding binding; + federated_language::TensorFlow::Binding binding; *binding.mutable_tensor()->mutable_tensor_name() = node->name(); return binding; } -inline v0::TensorFlow::Binding StructB( - const absl::Span elements) { - v0::TensorFlow::Binding binding; +inline federated_language::TensorFlow::Binding StructB( + const absl::Span elements) { + federated_language::TensorFlow::Binding binding; auto struct_mut = binding.mutable_struct_(); for (const auto& element : elements) { *struct_mut->add_element() = element; @@ -74,14 +74,14 @@ inline v0::TensorFlow::Binding StructB( return binding; } -inline v0::Computation ComputationV( +inline federated_language::Computation ComputationV( const tensorflow::Scope& scope, - std::optional in_binding, - v0::TensorFlow::Binding out_binding, + std::optional in_binding, + federated_language::TensorFlow::Binding out_binding, const std::optional& init_op = std::nullopt, const std::vector function_defs = {}) { - v0::Computation comp_pb; - v0::TensorFlow* tensorflow_pb = comp_pb.mutable_tensorflow(); + federated_language::Computation comp_pb; + federated_language::TensorFlow* tensorflow_pb = comp_pb.mutable_tensorflow(); tensorflow::GraphDef graphdef_pb; @@ -418,8 +418,8 @@ TEST_F(EagerComputationTest, CallAddGraphDefWithFunctionDef) { } TEST_F(EagerComputationTest, InvalidComputationProto) { - v0::Computation comp_pb; - v0::TensorFlow* tensorflow_pb = comp_pb.mutable_tensorflow(); + federated_language::Computation comp_pb; + federated_language::TensorFlow* tensorflow_pb = comp_pb.mutable_tensorflow(); tensorflow::TensorProto tensor_pb; tensorflow_pb->mutable_graph_def()->PackFrom(tensor_pb); diff --git a/tensorflow_federated/cc/core/impl/executors/executor.h b/tensorflow_federated/cc/core/impl/executors/executor.h index 6d02fe03b0..d557cb924c 100644 --- a/tensorflow_federated/cc/core/impl/executors/executor.h +++ b/tensorflow_federated/cc/core/impl/executors/executor.h @@ -36,9 +36,9 @@ limitations under the License #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "absl/utility/utility.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { diff --git a/tensorflow_federated/cc/core/impl/executors/executor_bindings.cc b/tensorflow_federated/cc/core/impl/executors/executor_bindings.cc index 9676f8d63a..e2e79bba3f 100644 --- a/tensorflow_federated/cc/core/impl/executors/executor_bindings.cc +++ b/tensorflow_federated/cc/core/impl/executors/executor_bindings.cc @@ -34,6 +34,7 @@ limitations under the License #include "include/grpcpp/impl/channel_interface.h" #include "include/grpcpp/security/credentials.h" #include "include/grpcpp/support/channel_arguments.h" +#include "federated_language/proto/computation.pb.h" #include "include/pybind11/cast.h" #include "include/pybind11/detail/common.h" #include "include/pybind11/pybind11.h" @@ -61,7 +62,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/streaming_remote_executor.h" #include "tensorflow_federated/cc/core/impl/executors/tensorflow_executor.h" #include "tensorflow_federated/cc/core/impl/executors/xla_executor.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow { diff --git a/tensorflow_federated/cc/core/impl/executors/executor_service.cc b/tensorflow_federated/cc/core/impl/executors/executor_service.cc index 44c9113e62..089fe1e4b0 100644 --- a/tensorflow_federated/cc/core/impl/executors/executor_service.cc +++ b/tensorflow_federated/cc/core/impl/executors/executor_service.cc @@ -31,10 +31,10 @@ limitations under the License #include "absl/synchronization/mutex.h" #include "include/grpcpp/server_context.h" #include "include/grpcpp/support/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/status_conversion.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { diff --git a/tensorflow_federated/cc/core/impl/executors/executor_service.h b/tensorflow_federated/cc/core/impl/executors/executor_service.h index 7496aeae50..f8d936300e 100644 --- a/tensorflow_federated/cc/core/impl/executors/executor_service.h +++ b/tensorflow_federated/cc/core/impl/executors/executor_service.h @@ -30,10 +30,10 @@ limitations under the License #include "absl/synchronization/mutex.h" #include "include/grpcpp/grpcpp.h" #include "include/grpcpp/support/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/status_conversion.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.grpc.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" diff --git a/tensorflow_federated/cc/core/impl/executors/executor_service_test.cc b/tensorflow_federated/cc/core/impl/executors/executor_service_test.cc index d17abc143d..3cdb86c348 100644 --- a/tensorflow_federated/cc/core/impl/executors/executor_service_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/executor_service_test.cc @@ -30,6 +30,7 @@ limitations under the License #include "absl/types/span.h" #include "include/grpcpp/grpcpp.h" #include "include/grpcpp/support/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/mock_executor.h" @@ -37,7 +38,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { diff --git a/tensorflow_federated/cc/core/impl/executors/federating_executor.cc b/tensorflow_federated/cc/core/impl/executors/federating_executor.cc index 7b8f10b939..cfd253e1d6 100644 --- a/tensorflow_federated/cc/core/impl/executors/federating_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/federating_executor.cc @@ -32,6 +32,7 @@ limitations under the License #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/macros.h" @@ -42,7 +43,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" #include "tensorflow_federated/cc/core/impl/executors/value_validation.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -787,7 +787,8 @@ class FederatingExecutor : public ExecutorBase { switch (value.type()) { case ExecutorValue::ValueType::CLIENTS: { v0::Value_Federated* federated_pb = value_pb->mutable_federated(); - v0::FederatedType* type_pb = federated_pb->mutable_type(); + federated_language::FederatedType* type_pb = + federated_pb->mutable_type(); // All-equal-ness is not stored, so must be assumed to be false. // If the Python type system expects the value to be all-equal, it can // simply extract the first element in the list. @@ -814,7 +815,8 @@ class FederatingExecutor : public ExecutorBase { } case ExecutorValue::ValueType::SERVER: { v0::Value_Federated* federated_pb = value_pb->mutable_federated(); - v0::FederatedType* type_pb = federated_pb->mutable_type(); + federated_language::FederatedType* type_pb = + federated_pb->mutable_type(); // Server placement is assumed to be of cardinality one, and so must be // all-equal. type_pb->set_all_equal(true); diff --git a/tensorflow_federated/cc/core/impl/executors/mock_data_backend.h b/tensorflow_federated/cc/core/impl/executors/mock_data_backend.h index 24eafd6dc6..4d0cdfdbda 100644 --- a/tensorflow_federated/cc/core/impl/executors/mock_data_backend.h +++ b/tensorflow_federated/cc/core/impl/executors/mock_data_backend.h @@ -21,9 +21,9 @@ limitations under the License #include "googlemock/include/gmock/gmock.h" #include "absl/status/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/data_backend.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -37,14 +37,15 @@ class MockDataBackend : public DataBackend { public: ~MockDataBackend() override = default; MOCK_METHOD(absl::Status, ResolveToValue, - (const v0::Data& data_reference, const v0::Type& type_reference, + (const federated_language::Data& data_reference, + const federated_language::Type& type_reference, v0::Value& data_out), (override)); inline void ExpectResolveToValue(std::string expected_uri, - v0::Type expected_type, + federated_language::Type expected_type, v0::Value to_return) { - v0::Data data; + federated_language::Data data; data.set_uri(std::move(expected_uri)); EXPECT_CALL(*this, ResolveToValue(EqualsProto(data), EqualsProto(expected_type), ::testing::_)) diff --git a/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.cc b/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.cc index ad4de0525d..a8d9a97886 100644 --- a/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.cc @@ -30,9 +30,9 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -55,7 +55,8 @@ using NamedValue = std::tuple>; // References within the lambda will be resolved using the attached scope. class ScopedLambda { public: - explicit ScopedLambda(v0::Lambda lambda_pb, std::shared_ptr scope) + explicit ScopedLambda(federated_language::Lambda lambda_pb, + std::shared_ptr scope) : lambda_pb_(std::move(lambda_pb)), scope_(std::move(scope)) {} ScopedLambda(ScopedLambda&& other) : lambda_pb_(std::move(other.lambda_pb_)), @@ -72,7 +73,7 @@ class ScopedLambda { } private: - v0::Lambda lambda_pb_; + federated_language::Lambda lambda_pb_; std::shared_ptr scope_; }; @@ -197,7 +198,7 @@ class ReferenceResolvingExecutor // etc. The method delegates to other Evaluate*() methods, and the result // depends on the type of computation being evaluated. absl::StatusOr> Evaluate( - const v0::Computation& computation_pb, + const federated_language::Computation& computation_pb, const std::shared_ptr& scope) const; protected: @@ -242,51 +243,55 @@ class ReferenceResolvingExecutor // Evaluates a block. // // The semantics of a block are documented on the - // `tensorflow_federated::v0::Block` message defined in + // `federated_language::Block` message defined in // tensorflow_federated/proto/v0/computation.proto absl::StatusOr> EvaluateBlock( - const v0::Block& block_pb, const std::shared_ptr& scope) const; + const federated_language::Block& block_pb, + const std::shared_ptr& scope) const; // Evaluates a reference. // // The semantics of a reference are documented on the - // `tensorflow_federated::v0::Reference` message defined in + // `federated_language::Reference` message defined in // tensorflow_federated/proto/v0/computation.proto absl::StatusOr> EvaluateReference( - const v0::Reference& reference_pb, + const federated_language::Reference& reference_pb, const std::shared_ptr& scope) const; // Evaluates a lambda. // // The semantics of a reference are documented on the - // `tensorflow_federated::v0::Lambda` message defined in + // `federated_language::Lambda` message defined in // tensorflow_federated/proto/v0/computation.proto absl::StatusOr> EvaluateLambda( - const v0::Lambda& lambda_pb, const std::shared_ptr& scope) const; + const federated_language::Lambda& lambda_pb, + const std::shared_ptr& scope) const; // Evaluates a call. // // The semantics of a reference are documented on the - // `tensorflow_federated::v0::Call` message defined in + // `federated_language::Call` message defined in // tensorflow_federated/proto/v0/computation.proto absl::StatusOr> EvaluateCall( - const v0::Call& call_pb, const std::shared_ptr& scope) const; + const federated_language::Call& call_pb, + const std::shared_ptr& scope) const; // Evaluates a struct. // // The semantics of a struct are documented on the - // `tensorflow_federated::v0::Struct` message defined in + // `federated_language::Struct` message defined in // tensorflow_federated/proto/v0/computation.proto absl::StatusOr> EvaluateStruct( - const v0::Struct& struct_pb, const std::shared_ptr& scope) const; + const federated_language::Struct& struct_pb, + const std::shared_ptr& scope) const; // Evaluates a selection. // // The semantics of a selection are documented on the - // `tensorflow_federated::v0::Selection` message defined in + // `federated_language::Selection` message defined in // tensorflow_federated/proto/v0/computation.proto absl::StatusOr> EvaluateSelection( - const v0::Selection& selection_pb, + const federated_language::Selection& selection_pb, const std::shared_ptr& scope) const; }; @@ -505,42 +510,42 @@ absl::StatusOr ReferenceResolvingExecutor::Embed( absl::StatusOr> ReferenceResolvingExecutor::Evaluate( - const v0::Computation& computation_pb, + const federated_language::Computation& computation_pb, const std::shared_ptr& scope) const { switch (computation_pb.computation_case()) { - case v0::Computation::kTensorflow: - case v0::Computation::kIntrinsic: - case v0::Computation::kData: - case v0::Computation::kPlacement: - case v0::Computation::kLiteral: - case v0::Computation::kXla: { + case federated_language::Computation::kTensorflow: + case federated_language::Computation::kIntrinsic: + case federated_language::Computation::kData: + case federated_language::Computation::kPlacement: + case federated_language::Computation::kLiteral: + case federated_language::Computation::kXla: { // Note: we're copying the Computation proto here, possibly a TensorFlow // graph which might have large constants, possibly making it expensive. // However, we've taken this approach because we don't always have a - // `Value` for each `Computation` proto (see `v0::Block::local`); this - // code is simpler and more homogenous. If profiling shows this is a - // hotspot we can optimize. + // `Value` for each `Computation` proto (see + // `federated_language::Block::local`); this code is simpler and more + // homogenous. If profiling shows this is a hotspot we can optimize. v0::Value child_value_pb; *child_value_pb.mutable_computation() = computation_pb; return std::make_shared( TFF_TRY(child_executor_->CreateValue(child_value_pb))); } - case v0::Computation::kReference: { + case federated_language::Computation::kReference: { return EvaluateReference(computation_pb.reference(), scope); } - case v0::Computation::kBlock: { + case federated_language::Computation::kBlock: { return EvaluateBlock(computation_pb.block(), scope); } - case v0::Computation::kLambda: { + case federated_language::Computation::kLambda: { return EvaluateLambda(computation_pb.lambda(), scope); } - case v0::Computation::kCall: { + case federated_language::Computation::kCall: { return EvaluateCall(computation_pb.call(), scope); } - case v0::Computation::kStruct: { + case federated_language::Computation::kStruct: { return EvaluateStruct(computation_pb.struct_(), scope); } - case v0::Computation::kSelection: { + case federated_language::Computation::kSelection: { return EvaluateSelection(computation_pb.selection(), scope); } default: @@ -552,14 +557,15 @@ ReferenceResolvingExecutor::Evaluate( absl::StatusOr> ReferenceResolvingExecutor::EvaluateBlock( - const v0::Block& block_pb, const std::shared_ptr& scope) const { + const federated_language::Block& block_pb, + const std::shared_ptr& scope) const { std::shared_ptr current_scope = scope; - auto local_pb_formatter = [](std::string* out, - const v0::Block::Local& local_pb) { - out->append(local_pb.name()); - }; + auto local_pb_formatter = + [](std::string* out, const federated_language::Block::Local& local_pb) { + out->append(local_pb.name()); + }; for (int i = 0; i < block_pb.local_size(); ++i) { - const v0::Block::Local& local_pb = block_pb.local(i); + const federated_language::Block::Local& local_pb = block_pb.local(i); std::shared_ptr value = TFF_TRY( Evaluate(local_pb.value(), current_scope), absl::StrCat( @@ -574,7 +580,7 @@ ReferenceResolvingExecutor::EvaluateBlock( absl::StatusOr> ReferenceResolvingExecutor::EvaluateReference( - const v0::Reference& reference_pb, + const federated_language::Reference& reference_pb, const std::shared_ptr& scope) const { std::shared_ptr resolved_value = TFF_TRY(scope->Resolve(reference_pb.name()), @@ -589,13 +595,15 @@ ReferenceResolvingExecutor::EvaluateReference( absl::StatusOr> ReferenceResolvingExecutor::EvaluateLambda( - const v0::Lambda& lambda_pb, const std::shared_ptr& scope) const { + const federated_language::Lambda& lambda_pb, + const std::shared_ptr& scope) const { return std::make_shared(ScopedLambda{lambda_pb, scope}); } absl::StatusOr> ReferenceResolvingExecutor::EvaluateCall( - const v0::Call& call_pb, const std::shared_ptr& scope) const { + const federated_language::Call& call_pb, + const std::shared_ptr& scope) const { std::shared_ptr function = TFF_TRY(Evaluate(call_pb.function(), scope)); std::optional> argument; @@ -607,10 +615,12 @@ ReferenceResolvingExecutor::EvaluateCall( absl::StatusOr> ReferenceResolvingExecutor::EvaluateStruct( - const v0::Struct& struct_pb, const std::shared_ptr& scope) const { + const federated_language::Struct& struct_pb, + const std::shared_ptr& scope) const { std::vector> elements; elements.reserve(struct_pb.element_size()); - for (const v0::Struct::Element& element_pb : struct_pb.element()) { + for (const federated_language::Struct::Element& element_pb : + struct_pb.element()) { elements.emplace_back(TFF_TRY(Evaluate(element_pb.value(), scope))); } return std::make_shared(std::move(elements)); @@ -618,7 +628,7 @@ ReferenceResolvingExecutor::EvaluateStruct( absl::StatusOr> ReferenceResolvingExecutor::EvaluateSelection( - const v0::Selection& selection_pb, + const federated_language::Selection& selection_pb, const std::shared_ptr& scope) const { return CreateSelectionInternal( TFF_TRY(Evaluate(selection_pb.source(), scope)), selection_pb.index()); diff --git a/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor_test.cc index 35c62e365e..210b72905e 100644 --- a/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/reference_resolving_executor_test.cc @@ -17,12 +17,12 @@ limitations under the License // // IMPORTANT: many of the `v0::Value` protocol buffer messages used in the unit // tests in this file are not well-formed from the view of the entire execution -// stack. Particularly `v0::Computation` message fields that are not used by -// the ReferenceResolvingExecutor are often ellided to assert that they are not -// dependend on. This generally means the test protos are only valid because the -// child executor is mocked out and returns a hardcoded result, and should not -// be used a reference for how a real `v0::Computation` protocol buffer message -// should look. +// stack. Particularly `federated_language::Computation` message fields that +// are not used by the ReferenceResolvingExecutor are often ellided to assert +// that they are not dependend on. This generally means the test protos are only +// valid because the child executor is mocked out and returns a hardcoded +// result, and should not be used a reference for how a real +// `federated_language::Computation` protocol buffer message should look. #include "tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.h" @@ -42,6 +42,8 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/math_ops.h" @@ -55,8 +57,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -97,16 +97,18 @@ MATCHER_P(HasValueId, expected_id, // Constructs simple graphs for testing Tensorflow backend computations. inline v0::Value NoArgConstantTfComputationV() { v0::Value value_pb; - v0::Computation* computation_pb = value_pb.mutable_computation(); + federated_language::Computation* computation_pb = + value_pb.mutable_computation(); // Build the graph. tensorflow::Scope root = tensorflow::Scope::NewRootScope(); tf::ops::OnesLike ones(root, tf::Tensor(1.0)); tensorflow::GraphDef graphdef_pb; QCHECK_OK(root.ToGraphDef(&graphdef_pb)); - v0::TensorFlow* tensorflow_pb = computation_pb->mutable_tensorflow(); + federated_language::TensorFlow* tensorflow_pb = + computation_pb->mutable_tensorflow(); tensorflow_pb->mutable_graph_def()->PackFrom(graphdef_pb); // Build the tensor bindings. - v0::TensorFlow::TensorBinding* result_binding_pb = + federated_language::TensorFlow::TensorBinding* result_binding_pb = tensorflow_pb->mutable_result()->mutable_tensor(); result_binding_pb->set_tensor_name(ones.node()->name()); return value_pb; @@ -114,21 +116,24 @@ inline v0::Value NoArgConstantTfComputationV() { inline v0::Value UnarySquareTfComputationV() { v0::Value value_pb; - v0::Computation* computation_pb = value_pb.mutable_computation(); + federated_language::Computation* computation_pb = + value_pb.mutable_computation(); // Build the graph. tensorflow::Scope root = tensorflow::Scope::NewRootScope(); tf::ops::Placeholder x(root, tf::DT_FLOAT); tf::ops::Square square(root, x); tensorflow::GraphDef graphdef_pb; QCHECK_OK(root.ToGraphDef(&graphdef_pb)); - v0::TensorFlow* tensorflow_pb = computation_pb->mutable_tensorflow(); + federated_language::TensorFlow* tensorflow_pb = + computation_pb->mutable_tensorflow(); tensorflow_pb->mutable_graph_def()->PackFrom(graphdef_pb); // Build the tensor bindings. - v0::TensorFlow::StructBinding* struct_paramter_pb = + federated_language::TensorFlow::StructBinding* struct_paramter_pb = tensorflow_pb->mutable_parameter()->mutable_struct_(); - v0::TensorFlow::Binding* x_binding_pb = struct_paramter_pb->add_element(); + federated_language::TensorFlow::Binding* x_binding_pb = + struct_paramter_pb->add_element(); x_binding_pb->mutable_tensor()->set_tensor_name(x.node()->name()); - v0::TensorFlow::TensorBinding* result_binding_pb = + federated_language::TensorFlow::TensorBinding* result_binding_pb = tensorflow_pb->mutable_result()->mutable_tensor(); result_binding_pb->set_tensor_name(square.node()->name()); return value_pb; @@ -136,7 +141,8 @@ inline v0::Value UnarySquareTfComputationV() { inline v0::Value BinaryAddTfComputationV() { v0::Value value_pb; - v0::Computation* computation_pb = value_pb.mutable_computation(); + federated_language::Computation* computation_pb = + value_pb.mutable_computation(); // Build the graph. tensorflow::Scope root = tensorflow::Scope::NewRootScope(); tf::ops::Placeholder x(root, tf::DT_FLOAT); @@ -144,16 +150,19 @@ inline v0::Value BinaryAddTfComputationV() { tf::ops::AddV2 sum(root, x, y); tensorflow::GraphDef graphdef_pb; QCHECK_OK(root.ToGraphDef(&graphdef_pb)); - v0::TensorFlow* tensorflow_pb = computation_pb->mutable_tensorflow(); + federated_language::TensorFlow* tensorflow_pb = + computation_pb->mutable_tensorflow(); tensorflow_pb->mutable_graph_def()->PackFrom(graphdef_pb); // Build the tensor bindings. - v0::TensorFlow::StructBinding* struct_paramter_pb = + federated_language::TensorFlow::StructBinding* struct_paramter_pb = tensorflow_pb->mutable_parameter()->mutable_struct_(); - v0::TensorFlow::Binding* x_binding_pb = struct_paramter_pb->add_element(); + federated_language::TensorFlow::Binding* x_binding_pb = + struct_paramter_pb->add_element(); x_binding_pb->mutable_tensor()->set_tensor_name(x.node()->name()); - v0::TensorFlow::Binding* y_binding_pb = struct_paramter_pb->add_element(); + federated_language::TensorFlow::Binding* y_binding_pb = + struct_paramter_pb->add_element(); y_binding_pb->mutable_tensor()->set_tensor_name(y.node()->name()); - v0::TensorFlow::TensorBinding* result_binding_pb = + federated_language::TensorFlow::TensorBinding* result_binding_pb = tensorflow_pb->mutable_result()->mutable_tensor(); result_binding_pb->set_tensor_name(sum.node()->name()); return value_pb; @@ -203,11 +212,12 @@ TEST_F(ReferenceResolvingExecutorTest, CreateValueSequence) { TEST_F(ReferenceResolvingExecutorTest, CreateValueFederatedTensor) { v0::Value federated_value_pb; v0::Value::Federated* federated_pb = federated_value_pb.mutable_federated(); - v0::FederatedType* type_pb = federated_pb->mutable_type(); + federated_language::FederatedType* type_pb = federated_pb->mutable_type(); type_pb->set_all_equal(false); type_pb->mutable_placement()->mutable_value()->set_uri(kTestPlacement); - v0::TensorType* tensor_type = type_pb->mutable_member()->mutable_tensor(); - tensor_type->set_dtype(v0::DataType::DT_FLOAT); + federated_language::TensorType* tensor_type = + type_pb->mutable_member()->mutable_tensor(); + tensor_type->set_dtype(federated_language::DataType::DT_FLOAT); constexpr int kNumClients = 3; for (int i = 0; i < kNumClients; ++i) { *federated_pb->add_value() = TensorV(i); @@ -249,15 +259,17 @@ TEST_F(ReferenceResolvingExecutorTest, CreateValueNestedStructOfTensor) { TEST_F(ReferenceResolvingExecutorTest, CreateValueFederatedStructOfTensor) { v0::Value federated_value_pb; v0::Value::Federated* federated_pb = federated_value_pb.mutable_federated(); - v0::FederatedType* type_pb = federated_pb->mutable_type(); + federated_language::FederatedType* type_pb = federated_pb->mutable_type(); type_pb->set_all_equal(false); type_pb->mutable_placement()->mutable_value()->set_uri(kTestPlacement); - v0::StructType* struct_type = type_pb->mutable_member()->mutable_struct_(); + federated_language::StructType* struct_type = + type_pb->mutable_member()->mutable_struct_(); constexpr int kNumFields = 3; for (int i = 0; i < kNumFields; ++i) { - v0::StructType::Element* element_pb = struct_type->add_element(); + federated_language::StructType::Element* element_pb = + struct_type->add_element(); element_pb->mutable_value()->mutable_tensor()->set_dtype( - v0::DataType::DT_FLOAT); + federated_language::DataType::DT_FLOAT); } constexpr int kNumClients = 3; for (int i = 0; i < kNumClients; ++i) { @@ -308,7 +320,7 @@ TEST_F(ReferenceResolvingExecutorTest, CreateValueComputationPlacement) { } TEST_F(ReferenceResolvingExecutorTest, NoArgLambda) { - v0::Computation data_pb = DataComputation("test_data_uri"); + federated_language::Computation data_pb = DataComputation("test_data_uri"); v0::Value lambda_pb = ComputationV(LambdaComputation(std::nullopt, data_pb)); auto create_result = test_executor_->CreateValue(lambda_pb); EXPECT_THAT(create_result, IsOkAndHolds(HasValueId(0))); @@ -383,7 +395,7 @@ TEST_F(ReferenceResolvingExecutorTest, LambdaStructArgumentLazilyEmbedded) { TEST_F(ReferenceResolvingExecutorTest, LambdaArgumentScopeHidesBlockNamedValue) { - v0::Computation data_pb = DataComputation("test_data_uri"); + federated_language::Computation data_pb = DataComputation("test_data_uri"); v0::Value lambda_pb = ComputationV(BlockComputation( {{"test_arg", data_pb}, {"test_lambda", @@ -471,7 +483,7 @@ TEST_F(ReferenceResolvingExecutorTest, {"test_ref2", DataComputation("test_data_uri2")}}, ReferenceComputation("test_ref"))); ValueId child_id = 3; - for (const v0::Block::Local& local_pb : + for (const federated_language::Block::Local& local_pb : block_value_pb.computation().block().local()) { EXPECT_CALL(*mock_executor_, Dispose(child_id)); EXPECT_CALL(*mock_executor_, @@ -495,7 +507,7 @@ TEST_F(ReferenceResolvingExecutorTest, {"test_ref", DataComputation("test_data_uri2")}}, ReferenceComputation("test_ref"))); ValueId child_id = 3; - for (const v0::Block::Local& local_pb : + for (const federated_language::Block::Local& local_pb : block_value_pb.computation().block().local()) { EXPECT_CALL(*mock_executor_, Dispose(child_id)); EXPECT_CALL(*mock_executor_, diff --git a/tensorflow_federated/cc/core/impl/executors/remote_executor.cc b/tensorflow_federated/cc/core/impl/executors/remote_executor.cc index e721df6463..1e8fca8e16 100644 --- a/tensorflow_federated/cc/core/impl/executors/remote_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/remote_executor.cc @@ -30,12 +30,12 @@ limitations under the License #include "absl/synchronization/mutex.h" #include "include/grpcpp/grpcpp.h" #include "include/grpcpp/support/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/status_conversion.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.grpc.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" @@ -147,11 +147,11 @@ class ExecutorValue { } const v0::ValueRef& Get() const { return value_ref_; } - const v0::Type& Type() const { return type_pb_; } + const federated_language::Type& Type() const { return type_pb_; } private: const v0::ValueRef value_ref_; - const v0::Type type_pb_; + const federated_language::Type type_pb_; const v0::ExecutorId executor_pb_; std::shared_ptr stub_; }; @@ -164,7 +164,7 @@ absl::Status RemoteExecutor::EnsureInitialized() { v0::GetExecutorRequest request; for (auto iter = cardinalities_.begin(); iter != cardinalities_.end(); ++iter) { - v0::Placement placement; + federated_language::Placement placement; placement.set_uri(iter->first); v0::Cardinality cardinality; *cardinality.mutable_placement() = placement; @@ -241,7 +241,7 @@ absl::StatusOr RemoteExecutor::CreateStruct( grpc::ClientContext context; std::vector> values = TFF_TRY(WaitAll(futures)); - v0::Type result_type; + federated_language::Type result_type; for (const std::shared_ptr& element : values) { v0::CreateStructRequest_Element struct_elem; *struct_elem.mutable_value_ref() = element->Get(); diff --git a/tensorflow_federated/cc/core/impl/executors/sequence_executor.cc b/tensorflow_federated/cc/core/impl/executors/sequence_executor.cc index 61a633c95b..b3ed8a9e5f 100644 --- a/tensorflow_federated/cc/core/impl/executors/sequence_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/sequence_executor.cc @@ -33,6 +33,7 @@ limitations under the License #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" @@ -43,7 +44,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/struct_traversal_order.h" #include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -68,12 +68,13 @@ class SequenceIterator { // Computes the number of tensors in a nested tensor type, returning an error // status if a type other than tensor or structure is encountered. -absl::StatusOr NumTensorsInType(const v0::Type& type) { +absl::StatusOr NumTensorsInType( + const federated_language::Type& type) { switch (type.type_case()) { - case v0::Type::kTensor: { + case federated_language::Type::kTensor: { return 1; } - case v0::Type::kStruct: { + case federated_language::Type::kStruct: { uint32_t total_count = 0; for (const auto& el_type : type.struct_().element()) { total_count += TFF_TRY(NumTensorsInType(el_type.value())); @@ -90,9 +91,9 @@ absl::StatusOr NumTensorsInType(const v0::Type& type) { absl::StatusOr EmbedTensorsAsType( const absl::Span tensors, - Executor& target_executor, const v0::Type& type) { + Executor& target_executor, const federated_language::Type& type) { switch (type.type_case()) { - case v0::Type::kTensor: { + case federated_language::Type::kTensor: { if (tensors.size() != 1) { return absl::InvalidArgumentError(absl::StrCat( "Attempted to embed a vector of tensors of length ", tensors.size(), @@ -104,7 +105,7 @@ absl::StatusOr EmbedTensorsAsType( TFF_TRY(SerializeTensorValue(tensors.at(0), &tensor_value)); return ShareValueId(TFF_TRY(target_executor.CreateValue(tensor_value))); } - case v0::Type::kStruct: { + case federated_language::Type::kStruct: { std::vector traversal_order = TFF_TRY(TFNestTraversalOrderFromStruct(type.struct_())); uint32_t next_element_index = 0; @@ -161,7 +162,7 @@ class DatasetIterator : public SequenceIterator { public: explicit DatasetIterator( std::unique_ptr iter, - v0::Type element_type) + federated_language::Type element_type) : ds_iterator_(std::move(iter)), element_type_(std::move(element_type)) {} ~DatasetIterator() final = default; @@ -185,7 +186,7 @@ class DatasetIterator : public SequenceIterator { private: DatasetIterator() = delete; std::unique_ptr ds_iterator_; - v0::Type element_type_; + federated_language::Type element_type_; }; class SequenceIterator; diff --git a/tensorflow_federated/cc/core/impl/executors/sequence_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/sequence_executor_test.cc index 5cc3e69e5d..01024ba465 100644 --- a/tensorflow_federated/cc/core/impl/executors/sequence_executor_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/sequence_executor_test.cc @@ -24,13 +24,13 @@ limitations under the License #include "googlemock/include/gmock/gmock.h" #include "googletest/include/gtest/gtest.h" #include "absl/status/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/executor_test_base.h" #include "tensorflow_federated/cc/core/impl/executors/mock_executor.h" #include "tensorflow_federated/cc/core/impl/executors/sequence_intrinsics.h" #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" namespace tensorflow_federated { @@ -236,7 +236,7 @@ TEST_F(SequenceExecutorTest, EmbedFailsWithBadType) { v0::Value sequence_pb = SequenceV(1, dataset_len, 1); // We mutate the element type of this Sequence value to a non-embeddable type. - v0::Type function_type; + federated_language::Type function_type; *function_type.mutable_function()->mutable_result() = sequence_pb.sequence().element_type(); @@ -333,13 +333,13 @@ TEST_F(SequenceExecutorTest, CreateCallNestedStructureSequenceReduce) { // Notice that the names appear in non-sorted order in the TFF type signature; // we explicitly test this case to ensure that our traversal corresponds to // tf.nest's traversal order, where the keys of ordered dicts are sorted. - v0::Type sequence_element_type; + federated_language::Type sequence_element_type; *sequence_element_type.mutable_struct_()->add_element()->mutable_value() = MakeInt64ScalarType(); - v0::Type* nested_struct_type = + federated_language::Type* nested_struct_type = sequence_element_type.mutable_struct_()->add_element()->mutable_value(); for (int i = 0; i < 2; i++) { - v0::StructType_Element* struct_elem = + federated_language::StructType_Element* struct_elem = nested_struct_type->mutable_struct_()->add_element(); *struct_elem->mutable_value() = MakeInt64ScalarType(); if (i == 0) { diff --git a/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor.cc b/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor.cc index 53977677db..fcdc59dc85 100644 --- a/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor.cc @@ -38,6 +38,7 @@ limitations under the License #include "absl/synchronization/mutex.h" #include "include/grpcpp/grpcpp.h" #include "include/grpcpp/support/status.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/federated_intrinsics.h" @@ -45,7 +46,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" #include "tensorflow_federated/cc/core/impl/executors/type_utils.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.grpc.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" @@ -102,12 +102,13 @@ absl::Status BuildPlacedStructValue(const v0::Value::Struct& struct_value_pb, // the remote executor can track the resulting value, which is necessary to // later stream results during materialization. absl::StatusOr CreateFederatedZipComputation( - const v0::StructType& parameter_type_pb, - const v0::FederatedType& result_type_pb, - const v0::PlacementSpec& placement_spec) { + const federated_language::StructType& parameter_type_pb, + const federated_language::FederatedType& result_type_pb, + const federated_language::PlacementSpec& placement_spec) { v0::Value intrinsic_pb; - v0::Computation* computation_pb = intrinsic_pb.mutable_computation(); - v0::FunctionType* computation_type_pb = + federated_language::Computation* computation_pb = + intrinsic_pb.mutable_computation(); + federated_language::FunctionType* computation_type_pb = computation_pb->mutable_type()->mutable_function(); *computation_type_pb->mutable_parameter()->mutable_struct_() = parameter_type_pb; @@ -127,19 +128,20 @@ absl::StatusOr CreateFederatedZipComputation( return intrinsic_pb; } -v0::Call CreateCalledFederatedMappedSelection(absl::string_view intrinsic_uri, - absl::string_view arg_ref_name, - int32_t index) { - v0::Call call_pb; +federated_language::Call CreateCalledFederatedMappedSelection( + absl::string_view intrinsic_uri, absl::string_view arg_ref_name, + int32_t index) { + federated_language::Call call_pb; call_pb.mutable_function()->mutable_intrinsic()->set_uri( intrinsic_uri.data(), intrinsic_uri.size()); - v0::Struct* arg_struct = call_pb.mutable_argument()->mutable_struct_(); + federated_language::Struct* arg_struct = + call_pb.mutable_argument()->mutable_struct_(); - v0::Lambda* local_lambda_pb = + federated_language::Lambda* local_lambda_pb = arg_struct->add_element()->mutable_value()->mutable_lambda(); constexpr char kMapArg[] = "map_arg"; local_lambda_pb->set_parameter_name(kMapArg); - v0::Selection* selection_pb = + federated_language::Selection* selection_pb = local_lambda_pb->mutable_result()->mutable_selection(); selection_pb->mutable_source()->mutable_reference()->set_name(kMapArg); selection_pb->set_index(index); @@ -173,7 +175,7 @@ v0::Call CreateCalledFederatedMappedSelection(absl::string_view intrinsic_uri, // in // )) absl::StatusOr CreateSelectionFederatedStructComputation( - const v0::FederatedType& parameter_type_pb) { + const federated_language::FederatedType& parameter_type_pb) { if (!parameter_type_pb.member().has_struct_()) { // We don't want to create and send RPCs for computations that don't require // them, make this an error condition. @@ -182,18 +184,19 @@ absl::StatusOr CreateSelectionFederatedStructComputation( parameter_type_pb.ShortDebugString())); } v0::Value value_pb; - v0::FunctionType* lambda_type_pb = + federated_language::FunctionType* lambda_type_pb = value_pb.mutable_computation()->mutable_type()->mutable_function(); *lambda_type_pb->mutable_parameter()->mutable_federated() = parameter_type_pb; // NOTE: the result type will be computed as we build the computation and set // at the end of this method. - v0::StructType result_type_pb; - v0::FederatedType federated_type_template_pb; + federated_language::StructType result_type_pb; + federated_language::FederatedType federated_type_template_pb; *federated_type_template_pb.mutable_placement() = parameter_type_pb.placement(); federated_type_template_pb.set_all_equal(parameter_type_pb.all_equal()); - v0::Lambda* lambda_pb = value_pb.mutable_computation()->mutable_lambda(); + federated_language::Lambda* lambda_pb = + value_pb.mutable_computation()->mutable_lambda(); constexpr char kFederatedStructArg[] = "federated_struct_arg"; lambda_pb->set_parameter_name(kFederatedStructArg); @@ -211,8 +214,9 @@ absl::StatusOr CreateSelectionFederatedStructComputation( // A list of elements to iteratively process as the method descends into a // nested structure. We perform a breadth-first-traversal of the nested // structure. - std::list> + std::list> structs_to_process = {{ lambda_pb->mutable_result()->mutable_block(), kFederatedStructArg, @@ -225,14 +229,14 @@ absl::StatusOr CreateSelectionFederatedStructComputation( output_struct_type_pb] = structs_to_process.front(); structs_to_process.pop_front(); for (int32_t i = 0; i < parent_struct_type_pb->element_size(); ++i) { - const v0::StructType::Element& element_type_pb = + const federated_language::StructType::Element& element_type_pb = parent_struct_type_pb->element(i); - v0::StructType::Element* output_element_type_pb = + federated_language::StructType::Element* output_element_type_pb = output_struct_type_pb->add_element(); switch (element_type_pb.value().type_case()) { - case v0::Type::kTensor: - case v0::Type::kSequence: { - v0::Block::Local* local_pb = block_pb->add_local(); + case federated_language::Type::kTensor: + case federated_language::Type::kSequence: { + federated_language::Block::Local* local_pb = block_pb->add_local(); local_pb->set_name(absl::StrCat("elem_", i)); *local_pb->mutable_value()->mutable_call() = CreateCalledFederatedMappedSelection(intrinsic_uri, @@ -244,19 +248,21 @@ absl::StatusOr CreateSelectionFederatedStructComputation( ->mutable_member() = element_type_pb.value(); break; } - case v0::Type::kStruct: { + case federated_language::Type::kStruct: { // Add a local to select the nested structure, and give it a name with // the selection path. std::string nested_struct_ref_name = absl::StrCat("nested_struct_", i); - v0::Block::Local* nested_struct_local_pb = block_pb->add_local(); + federated_language::Block::Local* nested_struct_local_pb = + block_pb->add_local(); nested_struct_local_pb->set_name(nested_struct_ref_name); *nested_struct_local_pb->mutable_value()->mutable_call() = CreateCalledFederatedMappedSelection(intrinsic_uri, parent_ref_name, i); // We now need to descend into this struct, add it to the list to // process. - v0::Block::Local* nested_block_pb = block_pb->add_local(); + federated_language::Block::Local* nested_block_pb = + block_pb->add_local(); nested_block_pb->set_name(absl::StrCat("elem_", i)); structs_to_process.emplace_back( nested_block_pb->mutable_value()->mutable_block(), @@ -272,8 +278,9 @@ absl::StatusOr CreateSelectionFederatedStructComputation( } // After traversing all the elements in the current structure, gather the // elements from the locals that need to be part of the output structure. - v0::Struct* result_struct = block_pb->mutable_result()->mutable_struct_(); - for (const v0::Block::Local& local_pb : block_pb->local()) { + federated_language::Struct* result_struct = + block_pb->mutable_result()->mutable_struct_(); + for (const federated_language::Block::Local& local_pb : block_pb->local()) { // Only pickup the elements, not local selections, in the file output. if (absl::StartsWith(local_pb.name(), "elem_")) { result_struct->add_element() @@ -382,7 +389,7 @@ class StreamingRemoteExecutor : public ExecutorBase { // structure and stream them back. class ExecutorValue { public: - ExecutorValue(v0::ValueRef value_ref, v0::Type type_pb, + ExecutorValue(v0::ValueRef value_ref, federated_language::Type type_pb, v0::ExecutorId executor_pb, std::shared_ptr stub) : value_ref_(std::move(value_ref)), @@ -407,11 +414,11 @@ class ExecutorValue { } const v0::ValueRef& Get() const { return value_ref_; } - const v0::Type& Type() const { return type_pb_; } + const federated_language::Type& Type() const { return type_pb_; } private: const v0::ValueRef value_ref_; - const v0::Type type_pb_; + const federated_language::Type type_pb_; const v0::ExecutorId executor_pb_; std::shared_ptr stub_; }; @@ -424,7 +431,7 @@ absl::Status StreamingRemoteExecutor::EnsureInitialized() { v0::GetExecutorRequest request; for (auto iter = cardinalities_.begin(); iter != cardinalities_.end(); ++iter) { - v0::Placement placement; + federated_language::Placement placement; placement.set_uri(iter->first); v0::Cardinality cardinality; *cardinality.mutable_placement() = placement; @@ -455,7 +462,8 @@ StreamingRemoteExecutor::CreateExecutorFederatedValueStreaming( // "federated-structure-of-values" to "structure-of-federated-values" for // streaming across the RPC channel. At the end, a `federated_zip` intrisic // call will promote the values back to a `federated_structure_of_values`. - const v0::PlacementSpec& placement_spec = federated_pb.type().placement(); + const federated_language::PlacementSpec& placement_spec = + federated_pb.type().placement(); const bool all_equal = federated_pb.type().all_equal(); const int32_t struct_size = federated_pb.type().member().struct_().element_size(); @@ -465,14 +473,15 @@ StreamingRemoteExecutor::CreateExecutorFederatedValueStreaming( } // We build up a type for the intrinsic parameter for the federated_zip // computation that will be called after the streaming structure. - v0::StructType parameter_type_pb; + federated_language::StructType parameter_type_pb; std::vector elements; elements.reserve(struct_size); v0::Value element_pb; for (int32_t i = 0; i < struct_size; ++i) { element_pb.Clear(); v0::Value::Federated* federated_element_pb = element_pb.mutable_federated(); - v0::FederatedType* federated_type = federated_element_pb->mutable_type(); + federated_language::FederatedType* federated_type = + federated_element_pb->mutable_type(); *federated_type->mutable_placement() = placement_spec; federated_type->set_all_equal(all_equal); // Note: ignoring the `name()` of the elements. @@ -530,7 +539,7 @@ absl::StatusOr StreamingRemoteExecutor::CreateExecutorValue( absl::StatusOr StreamingRemoteExecutor::CreateValueRPC( const v0::Value& value_pb) { - v0::Type type_pb = TFF_TRY(InferTypeFromValue(value_pb)); + federated_language::Type type_pb = TFF_TRY(InferTypeFromValue(value_pb)); VLOG(5) << "CreateValueRPC: [" << type_pb.ShortDebugString() << "]"; if (type_pb.has_function() || type_pb.ShortDebugString().empty()) { VLOG(5) << value_pb.Utf8DebugString(); @@ -597,8 +606,8 @@ absl::StatusOr StreamingRemoteExecutor::CreateStruct( grpc::ClientContext context; std::vector> values = TFF_TRY(WaitAll(futures)); - v0::Type result_type; - v0::StructType* struct_type = result_type.mutable_struct_(); + federated_language::Type result_type; + federated_language::StructType* struct_type = result_type.mutable_struct_(); for (const std::shared_ptr& element : values) { v0::CreateStructRequest_Element struct_elem; *struct_elem.mutable_value_ref() = element->Get(); @@ -622,7 +631,7 @@ absl::StatusOr StreamingRemoteExecutor::CreateSelection( this_keepalive = shared_from_this()]() -> absl::StatusOr> { std::shared_ptr source_value = TFF_TRY(Wait(source)); - const v0::Type& source_type_pb = source_value->Type(); + const federated_language::Type& source_type_pb = source_value->Type(); if (!source_type_pb.has_struct_()) { return absl::InvalidArgumentError( absl::StrCat("Error selecting from non-Struct value: ", @@ -636,7 +645,7 @@ absl::StatusOr StreamingRemoteExecutor::CreateSelection( request.set_index(index); grpc::Status status = this->stub_->CreateSelection(&context, request, &response); - const v0::Type element_type_pb = + const federated_language::Type element_type_pb = source_type_pb.struct_().element(index).value(); TFF_TRY(grpc_to_absl(status)); return std::make_shared(std::move(response.value_ref()), @@ -649,11 +658,12 @@ absl::Status StreamingRemoteExecutor::Materialize(ValueFuture value, v0::Value* value_pb) { std::shared_ptr value_ref = TFF_TRY(Wait(value)); switch (value_ref->Type().type_case()) { - case v0::Type::kTensor: { + case federated_language::Type::kTensor: { return MaterializeRPC(value, value_pb); } - case v0::Type::kStruct: { - const v0::StructType& struct_type_pb = value_ref->Type().struct_(); + case federated_language::Type::kStruct: { + const federated_language::StructType& struct_type_pb = + value_ref->Type().struct_(); v0::Value::Struct* struct_value_pb = value_pb->mutable_struct_(); for (int32_t i = 0; i < struct_type_pb.element_size(); ++i) { ValueFuture selection = TFF_TRY(CreateSelection(value, i)); @@ -662,8 +672,9 @@ absl::Status StreamingRemoteExecutor::Materialize(ValueFuture value, } return absl::OkStatus(); } - case v0::Type::kFederated: { - const v0::Type& member_type_pb = value_ref->Type().federated().member(); + case federated_language::Type::kFederated: { + const federated_language::Type& member_type_pb = + value_ref->Type().federated().member(); if (!member_type_pb.has_struct_()) { // If not struct, nothing to stream; forward call as-is. return MaterializeRPC(value, value_pb); diff --git a/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor_test.cc index 98131d54a0..436cb80c1f 100644 --- a/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/streaming_remote_executor_test.cc @@ -39,6 +39,8 @@ limitations under the License #include "absl/time/time.h" #include "absl/types/span.h" #include "include/grpcpp/support/status.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" #include "tensorflow_federated/cc/core/impl/executors/federated_intrinsics.h" @@ -47,8 +49,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.grpc.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" @@ -62,7 +62,7 @@ using testing::proto::IgnoringRepeatedFieldOrdering; inline v0::Value ServerV(v0::Value unplaced_value) { v0::Value server_value = testing::ServerV(unplaced_value); - absl::StatusOr inferred_type_pb = + absl::StatusOr inferred_type_pb = InferTypeFromValue(unplaced_value); CHECK(inferred_type_pb.ok()) << inferred_type_pb.status(); *server_value.mutable_federated()->mutable_type()->mutable_member() = @@ -73,7 +73,7 @@ inline v0::Value ServerV(v0::Value unplaced_value) { inline v0::Value ClientsV(absl::Span unplaced_values) { v0::Value clients_value = testing::ClientsV(unplaced_values); if (!unplaced_values.empty()) { - absl::StatusOr inferred_type_pb = + absl::StatusOr inferred_type_pb = InferTypeFromValue(unplaced_values[0]); CHECK(inferred_type_pb.ok()) << inferred_type_pb.status(); *clients_value.mutable_federated()->mutable_type()->mutable_member() = @@ -535,21 +535,22 @@ TEST_F(StreamingRemoteExecutorTest, CreateCallFnWithStructReturnType) { v0::Value tensor_three = TensorV(3.0f); // Create a no-arg lambda that returns a structure of tensors. v0::Value fn_value; - v0::Computation* fn_computation = fn_value.mutable_computation(); - v0::FunctionType* fn_type = + federated_language::Computation* fn_computation = + fn_value.mutable_computation(); + federated_language::FunctionType* fn_type = fn_computation->mutable_type()->mutable_function(); fn_type->mutable_parameter()->mutable_tensor()->set_dtype( - v0::DataType::DT_FLOAT); - v0::StructType* result_struct_type = + federated_language::DataType::DT_FLOAT); + federated_language::StructType* result_struct_type = fn_type->mutable_result()->mutable_struct_(); result_struct_type->add_element() ->mutable_value() ->mutable_tensor() - ->set_dtype(v0::DataType::DT_FLOAT); + ->set_dtype(federated_language::DataType::DT_FLOAT); result_struct_type->add_element() ->mutable_value() ->mutable_tensor() - ->set_dtype(v0::DataType::DT_FLOAT); + ->set_dtype(federated_language::DataType::DT_FLOAT); v0::Value arg_value = TensorV(4.0f); v0::Value materialized_value; @@ -870,7 +871,8 @@ TEST_F(StreamingRemoteExecutorTest, CreateSelectionWithError) { struct FederatedStructTestCase { std::function)> FederatedV; - std::function FederatedZipIntrinsicV; + std::function + FederatedZipIntrinsicV; std::string_view placement_uri; bool all_equal; }; @@ -927,7 +929,7 @@ TEST_P(StreamingRemoteExecutorFederatedStructsTest, RoundTripFederatedStruct) { .WillOnce(ReturnOkWithResponseId( "streamed_federated_struct")); - v0::FunctionType zip_at_placement_type_pb; + federated_language::FunctionType zip_at_placement_type_pb; { auto* param_struct_type = zip_at_placement_type_pb.mutable_parameter()->mutable_struct_(); @@ -936,7 +938,7 @@ TEST_P(StreamingRemoteExecutorFederatedStructsTest, RoundTripFederatedStruct) { ->mutable_value() ->mutable_federated(); param_federated_type->mutable_member()->mutable_tensor()->set_dtype( - v0::DataType::DT_FLOAT); + federated_language::DataType::DT_FLOAT); param_federated_type->set_all_equal(test_case.all_equal); param_federated_type->mutable_placement()->mutable_value()->set_uri( std::string(test_case.placement_uri)); @@ -952,7 +954,7 @@ TEST_P(StreamingRemoteExecutorFederatedStructsTest, RoundTripFederatedStruct) { result_struct_type->add_element() ->mutable_value() ->mutable_tensor() - ->set_dtype(v0::DataType::DT_FLOAT); + ->set_dtype(federated_language::DataType::DT_FLOAT); } } EXPECT_CALL( @@ -1206,32 +1208,34 @@ TEST_P(StreamingRemoteExecutorFederatedStructsTest, .WillOnce(ReturnOkWithResponseId( std::string(struct_ref))); - v0::FunctionType zip_at_placement_type_pb; - v0::StructType* param_struct_type = + federated_language::FunctionType zip_at_placement_type_pb; + federated_language::StructType* param_struct_type = zip_at_placement_type_pb.mutable_parameter()->mutable_struct_(); - v0::FederatedType* param_federated_type = + federated_language::FederatedType* param_federated_type = param_struct_type->add_element() ->mutable_value() ->mutable_federated(); param_federated_type->set_all_equal(test_case.all_equal); param_federated_type->mutable_placement()->mutable_value()->set_uri( std::string(test_case.placement_uri)); - v0::Type* param_value_type = param_federated_type->mutable_member(); + federated_language::Type* param_value_type = + param_federated_type->mutable_member(); if (absl::EndsWith(elem_ref, "struct")) { // Nested struct needs another layer. param_value_type = param_value_type->mutable_struct_() ->add_element() ->mutable_value(); } - param_value_type->mutable_tensor()->set_dtype(v0::DataType::DT_FLOAT); - v0::FederatedType* result_federated_type = + param_value_type->mutable_tensor()->set_dtype( + federated_language::DataType::DT_FLOAT); + federated_language::FederatedType* result_federated_type = zip_at_placement_type_pb.mutable_result()->mutable_federated(); result_federated_type->set_all_equal(test_case.all_equal); result_federated_type->mutable_placement()->mutable_value()->set_uri( std::string(test_case.placement_uri)); - v0::StructType* result_struct_type = + federated_language::StructType* result_struct_type = result_federated_type->mutable_member()->mutable_struct_(); - v0::Type* result_value_type = + federated_language::Type* result_value_type = result_struct_type->add_element()->mutable_value(); if (absl::EndsWith(elem_ref, "struct")) { // Nested struct needs another layer. @@ -1239,7 +1243,8 @@ TEST_P(StreamingRemoteExecutorFederatedStructsTest, ->add_element() ->mutable_value(); } - result_value_type->mutable_tensor()->set_dtype(v0::DataType::DT_FLOAT); + result_value_type->mutable_tensor()->set_dtype( + federated_language::DataType::DT_FLOAT); EXPECT_CALL(*mock_executor_service_, CreateValue(_, EqualsProto(CreateValueRequestForValue( @@ -1516,7 +1521,7 @@ INSTANTIATE_TEST_SUITE_P( {[](std::vector values) -> v0::Value { return ClientsV(std::move(values)); }, - [](v0::FunctionType type_pb) -> v0::Value { + [](federated_language::FunctionType type_pb) -> v0::Value { return testing::intrinsic::FederatedZipAtClientsV(type_pb); }, kClientsUri, false}, @@ -1524,7 +1529,7 @@ INSTANTIATE_TEST_SUITE_P( {[](std::vector values) -> v0::Value { return ServerV(values[0]); }, - [](v0::FunctionType type_pb) -> v0::Value { + [](federated_language::FunctionType type_pb) -> v0::Value { return testing::intrinsic::FederatedZipAtServerV(type_pb); }, kServerUri, true}, diff --git a/tensorflow_federated/cc/core/impl/executors/struct_traversal_order.h b/tensorflow_federated/cc/core/impl/executors/struct_traversal_order.h index 38160a4768..b36668598f 100644 --- a/tensorflow_federated/cc/core/impl/executors/struct_traversal_order.h +++ b/tensorflow_federated/cc/core/impl/executors/struct_traversal_order.h @@ -26,12 +26,12 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" +#include "federated_language/proto/computation.pb.h" namespace tensorflow_federated { inline std::vector NamesFromStructType( - const v0::StructType& struct_type) { + const federated_language::StructType& struct_type) { std::vector names; for (const auto& struct_el : struct_type.element()) { if (struct_el.name().length()) { @@ -44,7 +44,7 @@ inline std::vector NamesFromStructType( using NameAndIndex = std::pair; inline absl::StatusOr> TFNestTraversalOrderFromStruct( - const v0::StructType& struct_type) { + const federated_language::StructType& struct_type) { auto struct_names = NamesFromStructType(struct_type); std::vector traversal_order(struct_type.element_size(), 0); // Initialize traversal order as iteration order over the structure. diff --git a/tensorflow_federated/cc/core/impl/executors/struct_traversal_order_test.cc b/tensorflow_federated/cc/core/impl/executors/struct_traversal_order_test.cc index f17369cc80..e9b9a90229 100644 --- a/tensorflow_federated/cc/core/impl/executors/struct_traversal_order_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/struct_traversal_order_test.cc @@ -33,7 +33,7 @@ using ::testing::HasSubstr; TEST(TFNestTraversalOrderFromStructTest, UnnamedStructureReturnsIterationOrder) { - v0::Type struct_type; + federated_language::Type struct_type; struct_type.mutable_struct_()->mutable_element()->Add(); struct_type.mutable_struct_()->mutable_element()->Add(); std::vector expected_order = {0, 1}; @@ -44,7 +44,7 @@ TEST(TFNestTraversalOrderFromStructTest, } TEST(TFNestTraversalOrderFromStructTest, NamedStructureReturnsKeySortedOrder) { - v0::Type struct_type; + federated_language::Type struct_type; *struct_type.mutable_struct_()->mutable_element()->Add()->mutable_name() = "b"; *struct_type.mutable_struct_()->mutable_element()->Add()->mutable_name() = @@ -57,7 +57,7 @@ TEST(TFNestTraversalOrderFromStructTest, NamedStructureReturnsKeySortedOrder) { } TEST(TFNestTraversalOrderFromStructTest, PartiallyNamedStructureErrs) { - v0::Type struct_type; + federated_language::Type struct_type; *struct_type.mutable_struct_()->mutable_element()->Add()->mutable_name() = "a"; struct_type.mutable_struct_()->mutable_element()->Add(); diff --git a/tensorflow_federated/cc/core/impl/executors/tensor_serialization.cc b/tensorflow_federated/cc/core/impl/executors/tensor_serialization.cc index 46212d1072..e0e526ea75 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensor_serialization.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensor_serialization.cc @@ -18,14 +18,14 @@ limitations under the License #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow_federated/cc/core/impl/executors/array_shape_utils.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -34,7 +34,7 @@ absl::Status SerializeTensorValue(const tensorflow::Tensor tensor, v0::Value* value_pb) { // Repeated fields are used for strings and constants to maintain // compatibility with TensorFlow. - v0::Array array_pb; + federated_language::Array array_pb; if ((tensor.shape().dims() == 0 && !tensor.shape().unknown_rank()) || tensor.dtype() == tensorflow::DataType::DT_STRING) { array_pb = TFF_TRY(ArrayFromTensor(tensor)); @@ -58,7 +58,7 @@ absl::StatusOr DeserializeTensorValue( // Repeated fields are used for strings and constants to maintain // compatibility with TensorFlow. if (tensorflow_federated::IsScalar(value_pb.array().shape()) || - value_pb.array().dtype() == v0::DataType::DT_STRING) { + value_pb.array().dtype() == federated_language::DataType::DT_STRING) { return TensorFromArray(value_pb.array()); } else { return TensorFromArrayContent(value_pb.array()); diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_executor.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_executor.cc index 235b766915..1ac8fe84df 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_executor.cc @@ -50,6 +50,7 @@ limitations under the License #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/public/session.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/dataset_from_tensor_structures.h" #include "tensorflow_federated/cc/core/impl/executors/dataset_utils.h" #include "tensorflow_federated/cc/core/impl/executors/executor.h" @@ -58,7 +59,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" #include "tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -166,14 +166,15 @@ void AddDatasetToGraphOp(tensorflow::GraphDef& graphdef_pb, // │dependent│ // └─────────┘ // -// This is used on parameter bindings of `v0::TensorFlow` computations. This is -// the reverse of `AddSerializationOpsForResults`, which is used on the result -// bindings of the function. +// This is used on parameter bindings of `federated_language::TensorFlow` +// computations. This is the reverse of `AddSerializationOpsForResults`, which +// is used on the result bindings of the function. absl::Status AddDeserializationOpsForParameters( - tensorflow::GraphDef& graphdef_pb, v0::TensorFlow::Binding& binding, + tensorflow::GraphDef& graphdef_pb, + federated_language::TensorFlow::Binding& binding, std::string_view prefix = "root") { switch (binding.binding_case()) { - case v0::TensorFlow::Binding::kSequence: { + case federated_language::TensorFlow::Binding::kSequence: { // Get a copy of the name of the placeholder we're operating on. We're // going to clear/reset the binding and then rebuild the it but re-use // the placeholder op. @@ -220,7 +221,7 @@ absl::Status AddDeserializationOpsForParameters( dataset_placeholder_node_name); return absl::OkStatus(); } - case v0::TensorFlow::Binding::kStruct: { + case federated_language::TensorFlow::Binding::kStruct: { for (int i = 0; i < binding.struct_().element_size(); ++i) { auto& member = *binding.mutable_struct_()->mutable_element(i); TFF_TRY(AddDeserializationOpsForParameters( @@ -262,16 +263,18 @@ absl::Status AddDeserializationOpsForParameters( // │dataset variant tensor│ // └──────────────────────┘ // -// This is used on result bindings of `v0::TensorFlow` computations. This is -// the reverse of `AddDeserializationOpsForParameters`, which is used on the -// parameter bindings of the function. -absl::Status AddSerializationOpsForResults(tensorflow::GraphDef& graphdef_pb, - v0::TensorFlow::Binding& binding, - std::string_view prefix = "root") { +// This is used on result bindings of `federated_language::TensorFlow` +// computations. This is the reverse of `AddDeserializationOpsForParameters`, +// which is used on the parameter bindings of the function. +absl::Status AddSerializationOpsForResults( + tensorflow::GraphDef& graphdef_pb, + federated_language::TensorFlow::Binding& binding, + std::string_view prefix = "root") { switch (binding.binding_case()) { - case v0::TensorFlow::Binding::kSequence: { + case federated_language::TensorFlow::Binding::kSequence: { if (binding.sequence().binding_case() == - v0::TensorFlow::SequenceBinding::kGraphDefTensorName) { + federated_language::TensorFlow::SequenceBinding:: + kGraphDefTensorName) { // Already using the correct binding, simply return. return absl::OkStatus(); } @@ -293,7 +296,7 @@ absl::Status AddSerializationOpsForResults(tensorflow::GraphDef& graphdef_pb, graph_names.graph_def_tensor_name); return absl::OkStatus(); } - case v0::TensorFlow::Binding::kStruct: { + case federated_language::TensorFlow::Binding::kStruct: { for (int i = 0; i < binding.struct_().element_size(); ++i) { auto& member = *binding.mutable_struct_()->mutable_element(i); TFF_TRY(AddSerializationOpsForResults(graphdef_pb, member, @@ -311,8 +314,8 @@ absl::Status AddSerializationOpsForResults(tensorflow::GraphDef& graphdef_pb, absl::Status AddDatasetSerializationToSequenceBindings( tensorflow::GraphDef& graphdef_pb, - std::optional& parameter_binding, - v0::TensorFlow::Binding& result_binding) { + std::optional& parameter_binding, + federated_language::TensorFlow::Binding& result_binding) { if (parameter_binding != std::nullopt) { TFF_TRY(AddDeserializationOpsForParameters(graphdef_pb, parameter_binding.value())); @@ -326,16 +329,16 @@ absl::Status AddDatasetSerializationToSequenceBindings( class Computation { public: static absl::StatusOr> FromProto( - const v0::TensorFlow& comp_pb) { + const federated_language::TensorFlow& comp_pb) { tensorflow::GraphDef graphdef_pb; if (!comp_pb.graph_def().UnpackTo(&graphdef_pb)) { return absl::InternalError(ERR_LOG("Could not unpack graphdef proto")); } - std::optional parameter_shape; + std::optional parameter_shape; if (comp_pb.has_parameter()) { parameter_shape = comp_pb.parameter(); } - v0::TensorFlow::Binding result_shape = comp_pb.result(); + federated_language::TensorFlow::Binding result_shape = comp_pb.result(); TFF_TRY(AddDatasetSerializationToSequenceBindings( graphdef_pb, parameter_shape, result_shape)); std::vector output_tensor_names; @@ -348,10 +351,11 @@ class Computation { absl::StatusOr Call(std::optional arg); - Computation(tensorflow::GraphDef graph, std::string init_op, - std::optional parameter_shape, - v0::TensorFlow::Binding output_shape, - std::vector output_tensor_names) + Computation( + tensorflow::GraphDef graph, std::string init_op, + std::optional parameter_shape, + federated_language::TensorFlow::Binding output_shape, + std::vector output_tensor_names) : session_provider_(std::move(graph)), init_op_(std::move(init_op)), parameter_shape_(std::move(parameter_shape)), @@ -368,20 +372,20 @@ class Computation { private: static absl::Status TensorNamesFromBinding( - const v0::TensorFlow::Binding& binding, + const federated_language::TensorFlow::Binding& binding, std::vector* tensor_names) { switch (binding.binding_case()) { - case v0::TensorFlow::Binding::kTensor: { + case federated_language::TensorFlow::Binding::kTensor: { tensor_names->push_back(binding.tensor().tensor_name()); return absl::OkStatus(); } - case v0::TensorFlow::Binding::kStruct: { + case federated_language::TensorFlow::Binding::kStruct: { for (const auto& member : binding.struct_().element()) { TFF_TRY(TensorNamesFromBinding(member, tensor_names)); } return absl::OkStatus(); } - case v0::TensorFlow::Binding::kSequence: { + case federated_language::TensorFlow::Binding::kSequence: { tensor_names->push_back(binding.sequence().graph_def_tensor_name()); return absl::OkStatus(); } @@ -400,8 +404,8 @@ class Computation { SessionProvider session_provider_; std::string init_op_; - std::optional parameter_shape_; - v0::TensorFlow::Binding output_shape_; + std::optional parameter_shape_; + federated_language::TensorFlow::Binding output_shape_; std::vector output_tensor_names_; }; @@ -516,7 +520,7 @@ class ExecutorValue { Intrinsic intrinsic() const { return std::get(value_); } absl::Status Bind( - const v0::TensorFlow::Binding& shape, + const federated_language::TensorFlow::Binding& shape, std::vector>* bindings) const { switch (type()) { case ValueType::TENSOR: { @@ -595,15 +599,15 @@ class ExecutorValue { } static absl::StatusOr FromTensorsAndBindingStructure( - const v0::TensorFlow::Binding& binding_structure, + const federated_language::TensorFlow::Binding& binding_structure, absl::Span* tensors) { bool is_sequence = false; switch (binding_structure.binding_case()) { - case v0::TensorFlow::Binding::kSequence: { + case federated_language::TensorFlow::Binding::kSequence: { is_sequence = true; } TF_FALLTHROUGH_INTENDED; - case v0::TensorFlow::Binding::kTensor: { + case federated_language::TensorFlow::Binding::kTensor: { if (tensors->empty()) { return absl::InternalError( "TensorFlow computation had fewer output tensors than expected."); @@ -616,7 +620,7 @@ class ExecutorValue { return ExecutorValue(std::move(tensor)); } } - case v0::TensorFlow::Binding::kStruct: { + case federated_language::TensorFlow::Binding::kStruct: { auto elements = std::make_shared>(); elements->reserve(binding_structure.struct_().element_size()); for (const auto& e_structure : binding_structure.struct_().element()) { @@ -659,8 +663,9 @@ class ExecutorValue { std::shared_ptr>, Intrinsic> value_; - static absl::Status BindKindMismatch(const std::string_view value_kind, - const v0::TensorFlow::Binding& shape) { + static absl::Status BindKindMismatch( + const std::string_view value_kind, + const federated_language::TensorFlow::Binding& shape) { return absl::InvalidArgumentError( absl::StrCat("Attempted to bind ", value_kind, " value to argument of kind ", shape.Utf8DebugString())); @@ -841,9 +846,9 @@ class TensorFlowExecutor : public ExecutorBase { } absl::StatusOr CreateValueComputation( - const v0::Computation& comp_pb) { + const federated_language::Computation& comp_pb) { switch (comp_pb.computation_case()) { - case v0::Computation::kTensorflow: { + case federated_language::Computation::kTensorflow: { if (!comp_pb.tensorflow().has_cache_key() || comp_pb.tensorflow().cache_key().id() == 0) { // No ID to use for caching, simply create a computation and skip @@ -879,12 +884,12 @@ class TensorFlowExecutor : public ExecutorBase { } return ExecutorValue(computation); } - case v0::Computation::kLiteral: { + case federated_language::Computation::kLiteral: { const tensorflow::Tensor tensor = TFF_TRY(TensorFromArray(comp_pb.literal().value())); return ExecutorValue(std::move(tensor)); } - case v0::Computation::kIntrinsic: { + case federated_language::Computation::kIntrinsic: { Intrinsic intrinsic = TFF_TRY(IntrinsicFromUri(comp_pb.intrinsic().uri())); return ExecutorValue(intrinsic); diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_executor_parameterized_test.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_executor_parameterized_test.cc index 4bd5c583fe..abd5f0e32c 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_executor_parameterized_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_executor_parameterized_test.cc @@ -30,6 +30,9 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/types/span.h" #include "google/protobuf/io/zero_copy_stream_impl.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/tf_status.h" @@ -52,9 +55,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" ABSL_FLAG(std::string, reduce_graph_path, "", @@ -74,24 +74,24 @@ using ::testing::HasSubstr; using ::testing::Types; template -inline v0::TensorFlow::Binding TensorB(const TfOp& op) { +inline federated_language::TensorFlow::Binding TensorB(const TfOp& op) { const tensorflow::Node* node = op.node(); - v0::TensorFlow::Binding binding; + federated_language::TensorFlow::Binding binding; *binding.mutable_tensor()->mutable_tensor_name() = node->name(); return binding; } template -inline v0::TensorFlow::Binding SequenceB(const TfOp& op) { +inline federated_language::TensorFlow::Binding SequenceB(const TfOp& op) { const tensorflow::Node* node = op.node(); - v0::TensorFlow::Binding binding; + federated_language::TensorFlow::Binding binding; *binding.mutable_sequence()->mutable_variant_tensor_name() = node->name(); return binding; } -inline v0::TensorFlow::Binding StructB( - const absl::Span elements) { - v0::TensorFlow::Binding binding; +inline federated_language::TensorFlow::Binding StructB( + const absl::Span elements) { + federated_language::TensorFlow::Binding binding; auto struct_mut = binding.mutable_struct_(); for (const auto& element : elements) { *struct_mut->add_element() = element; @@ -122,14 +122,15 @@ ExecutorId ExecutorType() { return kTensorFlowExecutor; } inline v0::Value ComputationV( - std::optional in_binding, - v0::TensorFlow::Binding out_binding, const tensorflow::Scope& scope, + std::optional in_binding, + federated_language::TensorFlow::Binding out_binding, + const tensorflow::Scope& scope, const std::optional& init_op = std::nullopt) { v0::Value value_pb; - v0::Computation* comp_pb = value_pb.mutable_computation(); + federated_language::Computation* comp_pb = value_pb.mutable_computation(); // NOTE: we do not fill in the `type` field of `comp` because it is not needed // by the C++ TensorFlow executor. - v0::TensorFlow* tensorflow_pb = comp_pb->mutable_tensorflow(); + federated_language::TensorFlow* tensorflow_pb = comp_pb->mutable_tensorflow(); tensorflow::GraphDef graphdef_pb; tensorflow::Status status = scope.ToGraphDef(&graphdef_pb); CHECK(status.ok()) << status; @@ -280,10 +281,10 @@ char const* const kReduceResultOutputTensorName = "result_tensor"; // input dataset of `int64_t`s and return the sum of the elements. v0::Value CreateDatasetReduceComputationV() { v0::Value value_pb; - v0::Computation* comp_pb = value_pb.mutable_computation(); + federated_language::Computation* comp_pb = value_pb.mutable_computation(); // NOTE: we do not fill in the `type` field of `comp` because it is not needed // by the C++ TensorFlow executor. - v0::TensorFlow* tensorflow_pb = comp_pb->mutable_tensorflow(); + federated_language::TensorFlow* tensorflow_pb = comp_pb->mutable_tensorflow(); std::string reduce_graph_path = absl::GetFlag(FLAGS_reduce_graph_path); tensorflow::GraphDef graphdef_pb = LoadGraph(reduce_graph_path.c_str()); tensorflow_pb->mutable_graph_def()->PackFrom(graphdef_pb); @@ -600,12 +601,14 @@ class TensorFlowExecutorTest : public ::testing::Test { }; TEST_F(TensorFlowExecutorTest, CreateValueComputationLiteralReturnsResult) { - const v0::DataType dtype = v0::DataType::DT_INT32; - v0::ArrayShape shape_pb = testing::CreateArrayShape({}); + const federated_language::DataType dtype = + federated_language::DataType::DT_INT32; + federated_language::ArrayShape shape_pb = testing::CreateArrayShape({}); auto values = {1}; - v0::Array array_pb = + federated_language::Array array_pb = TFF_ASSERT_OK(testing::CreateArray(dtype, shape_pb, values)); - v0::Computation computation_pb = testing::LiteralComputation(array_pb); + federated_language::Computation computation_pb = + testing::LiteralComputation(array_pb); v0::Value value_pb = testing::ComputationV(computation_pb); const OwnedValueId& embedded_fn = diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc index 94ec5669f7..4d4eeb705b 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc @@ -28,6 +28,8 @@ limitations under the License #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "third_party/eigen3/Eigen/Core" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -40,46 +42,44 @@ limitations under the License #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" namespace tensorflow_federated { -absl::StatusOr DataTypeFromTensorFlowDataType( +absl::StatusOr DataTypeFromTensorFlowDataType( tensorflow::DataType data_type_pb) { switch (data_type_pb) { case tensorflow::DataType::DT_BOOL: - return v0::DataType::DT_BOOL; + return federated_language::DataType::DT_BOOL; case tensorflow::DataType::DT_INT8: - return v0::DataType::DT_INT8; + return federated_language::DataType::DT_INT8; case tensorflow::DataType::DT_INT16: - return v0::DataType::DT_INT16; + return federated_language::DataType::DT_INT16; case tensorflow::DataType::DT_INT32: - return v0::DataType::DT_INT32; + return federated_language::DataType::DT_INT32; case tensorflow::DataType::DT_INT64: - return v0::DataType::DT_INT64; + return federated_language::DataType::DT_INT64; case tensorflow::DataType::DT_UINT8: - return v0::DataType::DT_UINT8; + return federated_language::DataType::DT_UINT8; case tensorflow::DataType::DT_UINT16: - return v0::DataType::DT_UINT16; + return federated_language::DataType::DT_UINT16; case tensorflow::DataType::DT_UINT32: - return v0::DataType::DT_UINT32; + return federated_language::DataType::DT_UINT32; case tensorflow::DataType::DT_UINT64: - return v0::DataType::DT_UINT64; + return federated_language::DataType::DT_UINT64; case tensorflow::DataType::DT_HALF: - return v0::DataType::DT_HALF; + return federated_language::DataType::DT_HALF; case tensorflow::DataType::DT_FLOAT: - return v0::DataType::DT_FLOAT; + return federated_language::DataType::DT_FLOAT; case tensorflow::DataType::DT_DOUBLE: - return v0::DataType::DT_DOUBLE; + return federated_language::DataType::DT_DOUBLE; case tensorflow::DataType::DT_COMPLEX64: - return v0::DataType::DT_COMPLEX64; + return federated_language::DataType::DT_COMPLEX64; case tensorflow::DataType::DT_COMPLEX128: - return v0::DataType::DT_COMPLEX128; + return federated_language::DataType::DT_COMPLEX128; case tensorflow::DataType::DT_BFLOAT16: - return v0::DataType::DT_BFLOAT16; + return federated_language::DataType::DT_BFLOAT16; case tensorflow::DataType::DT_STRING: - return v0::DataType::DT_STRING; + return federated_language::DataType::DT_STRING; default: return absl::UnimplementedError( absl::StrCat("Unexpected DataType found:", data_type_pb)); @@ -87,39 +87,39 @@ absl::StatusOr DataTypeFromTensorFlowDataType( } absl::StatusOr TensorFlowDataTypeFromDataType( - v0::DataType data_type_pb) { + federated_language::DataType data_type_pb) { switch (data_type_pb) { - case v0::DataType::DT_BOOL: + case federated_language::DataType::DT_BOOL: return tensorflow::DataType::DT_BOOL; - case v0::DataType::DT_INT8: + case federated_language::DataType::DT_INT8: return tensorflow::DataType::DT_INT8; - case v0::DataType::DT_INT16: + case federated_language::DataType::DT_INT16: return tensorflow::DataType::DT_INT16; - case v0::DataType::DT_INT32: + case federated_language::DataType::DT_INT32: return tensorflow::DataType::DT_INT32; - case v0::DataType::DT_INT64: + case federated_language::DataType::DT_INT64: return tensorflow::DataType::DT_INT64; - case v0::DataType::DT_UINT8: + case federated_language::DataType::DT_UINT8: return tensorflow::DataType::DT_UINT8; - case v0::DataType::DT_UINT16: + case federated_language::DataType::DT_UINT16: return tensorflow::DataType::DT_UINT16; - case v0::DataType::DT_UINT32: + case federated_language::DataType::DT_UINT32: return tensorflow::DataType::DT_UINT32; - case v0::DataType::DT_UINT64: + case federated_language::DataType::DT_UINT64: return tensorflow::DataType::DT_UINT64; - case v0::DataType::DT_HALF: + case federated_language::DataType::DT_HALF: return tensorflow::DataType::DT_HALF; - case v0::DataType::DT_FLOAT: + case federated_language::DataType::DT_FLOAT: return tensorflow::DataType::DT_FLOAT; - case v0::DataType::DT_DOUBLE: + case federated_language::DataType::DT_DOUBLE: return tensorflow::DataType::DT_DOUBLE; - case v0::DataType::DT_COMPLEX64: + case federated_language::DataType::DT_COMPLEX64: return tensorflow::DataType::DT_COMPLEX64; - case v0::DataType::DT_COMPLEX128: + case federated_language::DataType::DT_COMPLEX128: return tensorflow::DataType::DT_COMPLEX128; - case v0::DataType::DT_BFLOAT16: + case federated_language::DataType::DT_BFLOAT16: return tensorflow::DataType::DT_BFLOAT16; - case v0::DataType::DT_STRING: + case federated_language::DataType::DT_STRING: return tensorflow::DataType::DT_STRING; default: return absl::UnimplementedError( @@ -127,9 +127,9 @@ absl::StatusOr TensorFlowDataTypeFromDataType( } } -absl::StatusOr ArrayShapeFromTensorShape( +absl::StatusOr ArrayShapeFromTensorShape( const tensorflow::TensorShape& tensor_shape) { - v0::ArrayShape shape_pb; + federated_language::ArrayShape shape_pb; for (int i = 0; i < tensor_shape.dims(); i++) { shape_pb.mutable_dim()->Add(tensor_shape.dim_size(i)); } @@ -138,10 +138,11 @@ absl::StatusOr ArrayShapeFromTensorShape( } absl::StatusOr TensorShapeFromArrayShape( - const v0::ArrayShape& shape_pb) { + const federated_language::ArrayShape& shape_pb) { if (shape_pb.unknown_rank()) { return absl::InvalidArgumentError( - "Expected v0::ArrayShape to have a known rank, try constructing " + "Expected federated_language::ArrayShape to have a known rank, try " + "constructing " "a tensorflow::PartialTensorShape using " "tensorflow_federated::PartialTensorShapeFromArrayShape instead."); } @@ -152,7 +153,7 @@ absl::StatusOr TensorShapeFromArrayShape( } tensorflow::PartialTensorShape PartialTensorShapeFromArrayShape( - const v0::ArrayShape& shape_pb) { + const federated_language::ArrayShape& shape_pb) { if (!shape_pb.unknown_rank()) { return tensorflow::PartialTensorShape(shape_pb.dim()); } else { @@ -160,12 +161,14 @@ tensorflow::PartialTensorShape PartialTensorShapeFromArrayShape( } } -absl::StatusOr ArrayFromTensor(const tensorflow::Tensor& tensor) { - v0::Array array_pb; - v0::DataType data_type = +absl::StatusOr ArrayFromTensor( + const tensorflow::Tensor& tensor) { + federated_language::Array array_pb; + federated_language::DataType data_type = TFF_TRY(DataTypeFromTensorFlowDataType(tensor.dtype())); array_pb.set_dtype(data_type); - v0::ArrayShape shape_pb = TFF_TRY(ArrayShapeFromTensorShape(tensor.shape())); + federated_language::ArrayShape shape_pb = + TFF_TRY(ArrayShapeFromTensorShape(tensor.shape())); array_pb.mutable_shape()->Swap(&shape_pb); tensorflow::TensorProto tensor_pb; @@ -313,9 +316,10 @@ static void CopyFromRepeatedField( std::copy(src.begin(), src.end(), dest); } -absl::StatusOr TensorFromArray(const v0::Array& array_pb) { +absl::StatusOr TensorFromArray( + const federated_language::Array& array_pb) { switch (array_pb.kind_case()) { - case v0::Array::kBoolList: { + case federated_language::Array::kBoolList: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -323,7 +327,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kInt8List: { + case federated_language::Array::kInt8List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -331,7 +335,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kInt16List: { + case federated_language::Array::kInt16List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -339,7 +343,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kInt32List: { + case federated_language::Array::kInt32List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -347,7 +351,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kInt64List: { + case federated_language::Array::kInt64List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -355,7 +359,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kUint8List: { + case federated_language::Array::kUint8List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -363,7 +367,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kUint16List: { + case federated_language::Array::kUint16List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -371,7 +375,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kUint32List: { + case federated_language::Array::kUint32List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -379,7 +383,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kUint64List: { + case federated_language::Array::kUint64List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -387,7 +391,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kFloat16List: { + case federated_language::Array::kFloat16List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -395,7 +399,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kFloat32List: { + case federated_language::Array::kFloat32List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -403,7 +407,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kFloat64List: { + case federated_language::Array::kFloat64List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -411,7 +415,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kComplex64List: { + case federated_language::Array::kComplex64List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -419,7 +423,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kComplex128List: { + case federated_language::Array::kComplex128List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -427,7 +431,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kBfloat16List: { + case federated_language::Array::kBfloat16List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -435,7 +439,7 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } - case v0::Array::kStringList: { + case federated_language::Array::kStringList: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); @@ -449,13 +453,14 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { } } -absl::StatusOr ArrayContentFromTensor( +absl::StatusOr ArrayContentFromTensor( const tensorflow::Tensor& tensor) { - v0::Array array_pb; - v0::DataType data_type = + federated_language::Array array_pb; + federated_language::DataType data_type = TFF_TRY(DataTypeFromTensorFlowDataType(tensor.dtype())); array_pb.set_dtype(data_type); - v0::ArrayShape shape_pb = TFF_TRY(ArrayShapeFromTensorShape(tensor.shape())); + federated_language::ArrayShape shape_pb = + TFF_TRY(ArrayShapeFromTensorShape(tensor.shape())); array_pb.mutable_shape()->Swap(&shape_pb); tensorflow::TensorProto tensor_pb; tensor.AsProtoTensorContent(&tensor_pb); @@ -465,7 +470,7 @@ absl::StatusOr ArrayContentFromTensor( } absl::StatusOr TensorFromArrayContent( - const v0::Array& array_pb) { + const federated_language::Array& array_pb) { if (!array_pb.has_content()) { return absl::InvalidArgumentError("Expected a content field, found none."); } diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h index ee481a6ac5..2803e9406c 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.h @@ -20,29 +20,32 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "federated_language/proto/array.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow_federated/proto/v0/array.pb.h" namespace tensorflow_federated { -// Creates a tensorflow::TensorShape from a v0::ArrayShape. +// Creates a tensorflow::TensorShape from a federated_language::ArrayShape. absl::StatusOr TensorShapeFromArrayShape( - const v0::ArrayShape& shape_pb); + const federated_language::ArrayShape& shape_pb); -// Creates a tensorflow::PartialTensorShape from a v0::ArrayShape. +// Creates a tensorflow::PartialTensorShape from a +// federated_language::ArrayShape. tensorflow::PartialTensorShape PartialTensorShapeFromArrayShape( - const v0::ArrayShape& shape_pb); + const federated_language::ArrayShape& shape_pb); -// Creates an v0::Array from a tensorflow::Tensor. -absl::StatusOr ArrayFromTensor(const tensorflow::Tensor& tensor); -absl::StatusOr ArrayContentFromTensor( +// Creates an federated_language::Array from a tensorflow::Tensor. +absl::StatusOr ArrayFromTensor( + const tensorflow::Tensor& tensor); +absl::StatusOr ArrayContentFromTensor( const tensorflow::Tensor& tensor); -// Creates a tensorflow::Tensor from an v0::Array. -absl::StatusOr TensorFromArray(const v0::Array& array_pb); +// Creates a tensorflow::Tensor from an federated_language::Array. +absl::StatusOr TensorFromArray( + const federated_language::Array& array_pb); absl::StatusOr TensorFromArrayContent( - const v0::Array& array_pb); + const federated_language::Array& array_pb); std::string GetNodeName(absl::string_view tensor_name); diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc index 97de352076..5b13a6d61c 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc @@ -26,6 +26,8 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "third_party/eigen3/Eigen/Core" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -36,14 +38,13 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/array_test_utils.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" namespace tensorflow_federated { namespace { TEST(TensorShapeFromArrayShapeTest, TestReturnsTensorShape_fully_defined) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({2, 3}); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({2, 3}); const tensorflow::TensorShape& expected_shape = tensorflow::TensorShape({2, 3}); @@ -54,7 +55,8 @@ TEST(TensorShapeFromArrayShapeTest, TestReturnsTensorShape_fully_defined) { } TEST(TensorShapeFromArrayShapeTest, TestReturnsTensorShape_scalar) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({}); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({}); const tensorflow::TensorShape& expected_shape = tensorflow::TensorShape({}); const tensorflow::TensorShape& actual_shape = @@ -64,7 +66,8 @@ TEST(TensorShapeFromArrayShapeTest, TestReturnsTensorShape_scalar) { } TEST(TensorShapeFromArrayShapeTest, TestFails_partially_defined) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({2, -1}); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({2, -1}); const absl::StatusOr& result = TensorShapeFromArrayShape(shape_pb); @@ -73,7 +76,8 @@ TEST(TensorShapeFromArrayShapeTest, TestFails_partially_defined) { } TEST(TensorShapeFromArrayShapeTest, TestFails_unknown) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({}, true); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({}, true); const absl::StatusOr& result = TensorShapeFromArrayShape(shape_pb); @@ -83,7 +87,7 @@ TEST(TensorShapeFromArrayShapeTest, TestFails_unknown) { struct PartialTensorShapeFromArrayShapeTestCase { std::string test_name; - const v0::ArrayShape shape_pb; + const federated_language::ArrayShape shape_pb; const tensorflow::PartialTensorShape expected_shape; }; @@ -132,7 +136,7 @@ INSTANTIATE_TEST_SUITE_P( struct ArrayFromTensorTestCase { std::string test_name; const tensorflow::Tensor tensor; - const v0::Array expected_array_pb; + const federated_language::Array expected_array_pb; }; using ArrayFromTensorTest = ::testing::TestWithParam; @@ -140,7 +144,7 @@ using ArrayFromTensorTest = ::testing::TestWithParam; TEST_P(ArrayFromTensorTest, TestReturnsTensor) { const ArrayFromTensorTestCase& test_case = GetParam(); - const v0::Array& actual_array_pb = + const federated_language::Array& actual_array_pb = TFF_ASSERT_OK(ArrayFromTensor(test_case.tensor)); EXPECT_THAT(actual_array_pb, @@ -153,70 +157,70 @@ INSTANTIATE_TEST_SUITE_P( { "bool", tensorflow::test::AsScalar(true), - testing::CreateArray(v0::DataType::DT_BOOL, + testing::CreateArray(federated_language::DataType::DT_BOOL, testing::CreateArrayShape({}), {true}) .value(), }, { "int8", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_INT8, + testing::CreateArray(federated_language::DataType::DT_INT8, testing::CreateArrayShape({}), {1}) .value(), }, { "int16", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_INT16, + testing::CreateArray(federated_language::DataType::DT_INT16, testing::CreateArrayShape({}), {1}) .value(), }, { "int32", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_INT32, + testing::CreateArray(federated_language::DataType::DT_INT32, testing::CreateArrayShape({}), {1}) .value(), }, { "int64", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_INT64, + testing::CreateArray(federated_language::DataType::DT_INT64, testing::CreateArrayShape({}), {1}) .value(), }, { "uint8", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_UINT8, + testing::CreateArray(federated_language::DataType::DT_UINT8, testing::CreateArrayShape({}), {1}) .value(), }, { "uint16", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_UINT16, + testing::CreateArray(federated_language::DataType::DT_UINT16, testing::CreateArrayShape({}), {1}) .value(), }, { "uint32", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_UINT32, + testing::CreateArray(federated_language::DataType::DT_UINT32, testing::CreateArrayShape({}), {1}) .value(), }, { "uint64", tensorflow::test::AsScalar(1), - testing::CreateArray(v0::DataType::DT_UINT64, + testing::CreateArray(federated_language::DataType::DT_UINT64, testing::CreateArrayShape({}), {1}) .value(), }, { "float16", tensorflow::test::AsScalar(Eigen::half{1.0}), - testing::CreateArray(v0::DataType::DT_HALF, + testing::CreateArray(federated_language::DataType::DT_HALF, testing::CreateArrayShape({}), {Eigen::half{1.0}}) .value(), @@ -224,21 +228,21 @@ INSTANTIATE_TEST_SUITE_P( { "float32", tensorflow::test::AsScalar(1.0), - testing::CreateArray(v0::DataType::DT_FLOAT, + testing::CreateArray(federated_language::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0}) .value(), }, { "float64", tensorflow::test::AsScalar(1.0), - testing::CreateArray(v0::DataType::DT_DOUBLE, + testing::CreateArray(federated_language::DataType::DT_DOUBLE, testing::CreateArrayShape({}), {1.0}) .value(), }, { "complex64", tensorflow::test::AsScalar(tensorflow::complex64{1.0, 1.0}), - testing::CreateArray(v0::DataType::DT_COMPLEX64, + testing::CreateArray(federated_language::DataType::DT_COMPLEX64, testing::CreateArrayShape({}), {std::complex(1.0, 1.0)}) .value(), @@ -246,7 +250,7 @@ INSTANTIATE_TEST_SUITE_P( { "complex128", tensorflow::test::AsScalar(tensorflow::complex128{1.0, 1.0}), - testing::CreateArray(v0::DataType::DT_COMPLEX128, + testing::CreateArray(federated_language::DataType::DT_COMPLEX128, testing::CreateArrayShape({}), {std::complex(1.0, 1.0)}) .value(), @@ -254,7 +258,7 @@ INSTANTIATE_TEST_SUITE_P( { "bfloat16", tensorflow::test::AsScalar(Eigen::bfloat16{1.0}), - testing::CreateArray(v0::DataType::DT_BFLOAT16, + testing::CreateArray(federated_language::DataType::DT_BFLOAT16, testing::CreateArrayShape({}), {Eigen::bfloat16{1.0}}) .value(), @@ -262,7 +266,7 @@ INSTANTIATE_TEST_SUITE_P( { "string", tensorflow::test::AsScalar("a"), - testing::CreateArray(v0::DataType::DT_STRING, + testing::CreateArray(federated_language::DataType::DT_STRING, testing::CreateArrayShape({}), {"a"}) .value(), }, @@ -270,7 +274,7 @@ INSTANTIATE_TEST_SUITE_P( "array", tensorflow::test::AsTensor( {1, 2, 3, 4, 5, 6}, tensorflow::TensorShape({2, 3})), - testing::CreateArray(v0::DataType::DT_INT32, + testing::CreateArray(federated_language::DataType::DT_INT32, testing::CreateArrayShape({2, 3}), {1, 2, 3, 4, 5, 6}) .value(), @@ -282,7 +286,7 @@ INSTANTIATE_TEST_SUITE_P( struct TensorFromArrayTestCase { std::string test_name; - const v0::Array array_pb; + const federated_language::Array array_pb; const tensorflow::Tensor expected_tensor; }; @@ -302,70 +306,70 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn({ { "bool", - testing::CreateArray(v0::DataType::DT_BOOL, + testing::CreateArray(federated_language::DataType::DT_BOOL, testing::CreateArrayShape({}), {true}) .value(), tensorflow::test::AsScalar(true), }, { "int8", - testing::CreateArray(v0::DataType::DT_INT8, + testing::CreateArray(federated_language::DataType::DT_INT8, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "int16", - testing::CreateArray(v0::DataType::DT_INT16, + testing::CreateArray(federated_language::DataType::DT_INT16, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "int32", - testing::CreateArray(v0::DataType::DT_INT32, + testing::CreateArray(federated_language::DataType::DT_INT32, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "int64", - testing::CreateArray(v0::DataType::DT_INT64, + testing::CreateArray(federated_language::DataType::DT_INT64, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "uint8", - testing::CreateArray(v0::DataType::DT_UINT8, + testing::CreateArray(federated_language::DataType::DT_UINT8, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "uint16", - testing::CreateArray(v0::DataType::DT_UINT16, + testing::CreateArray(federated_language::DataType::DT_UINT16, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "uint32", - testing::CreateArray(v0::DataType::DT_UINT32, + testing::CreateArray(federated_language::DataType::DT_UINT32, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "uint64", - testing::CreateArray(v0::DataType::DT_UINT64, + testing::CreateArray(federated_language::DataType::DT_UINT64, testing::CreateArrayShape({}), {1}) .value(), tensorflow::test::AsScalar(1), }, { "float16", - testing::CreateArray(v0::DataType::DT_HALF, + testing::CreateArray(federated_language::DataType::DT_HALF, testing::CreateArrayShape({}), {Eigen::half{1.0}}) .value(), @@ -373,21 +377,21 @@ INSTANTIATE_TEST_SUITE_P( }, { "float32", - testing::CreateArray(v0::DataType::DT_FLOAT, + testing::CreateArray(federated_language::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0}) .value(), tensorflow::test::AsScalar(1.0), }, { "float64", - testing::CreateArray(v0::DataType::DT_DOUBLE, + testing::CreateArray(federated_language::DataType::DT_DOUBLE, testing::CreateArrayShape({}), {1.0}) .value(), tensorflow::test::AsScalar(1.0), }, { "complex64", - testing::CreateArray(v0::DataType::DT_COMPLEX64, + testing::CreateArray(federated_language::DataType::DT_COMPLEX64, testing::CreateArrayShape({}), {std::complex(1.0, 1.0)}) .value(), @@ -395,7 +399,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "complex128", - testing::CreateArray(v0::DataType::DT_COMPLEX128, + testing::CreateArray(federated_language::DataType::DT_COMPLEX128, testing::CreateArrayShape({}), {std::complex(1.0, 1.0)}) .value(), @@ -403,7 +407,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "bfloat16", - testing::CreateArray(v0::DataType::DT_BFLOAT16, + testing::CreateArray(federated_language::DataType::DT_BFLOAT16, testing::CreateArrayShape({}), {Eigen::bfloat16{1.0}}) .value(), @@ -411,14 +415,14 @@ INSTANTIATE_TEST_SUITE_P( }, { "string", - testing::CreateArray(v0::DataType::DT_STRING, + testing::CreateArray(federated_language::DataType::DT_STRING, testing::CreateArrayShape({}), {"a"}) .value(), tensorflow::test::AsScalar("a"), }, { "array", - testing::CreateArray(v0::DataType::DT_INT32, + testing::CreateArray(federated_language::DataType::DT_INT32, testing::CreateArrayShape({2, 3}), {1, 2, 3, 4, 5, 6}) .value(), @@ -433,7 +437,7 @@ INSTANTIATE_TEST_SUITE_P( struct ArrayContentFromTensorTestCase { std::string test_name; const tensorflow::Tensor tensor; - const v0::Array expected_array_pb; + const federated_language::Array expected_array_pb; }; using ArrayContentFromTensorTest = @@ -442,7 +446,7 @@ using ArrayContentFromTensorTest = TEST_P(ArrayContentFromTensorTest, TestReturnsTensor) { const ArrayContentFromTensorTestCase& test_case = GetParam(); - const v0::Array& actual_array_pb = + const federated_language::Array& actual_array_pb = TFF_ASSERT_OK(ArrayContentFromTensor(test_case.tensor)); EXPECT_THAT(actual_array_pb, @@ -457,7 +461,7 @@ INSTANTIATE_TEST_SUITE_P( { "bool", tensorflow::test::AsScalar(true), - testing::CreateArrayContent(v0::DataType::DT_BOOL, + testing::CreateArrayContent(federated_language::DataType::DT_BOOL, testing::CreateArrayShape({}), CONTENT("\001")) .value(), @@ -465,7 +469,7 @@ INSTANTIATE_TEST_SUITE_P( { "int8", tensorflow::test::AsScalar(1), - testing::CreateArrayContent(v0::DataType::DT_INT8, + testing::CreateArrayContent(federated_language::DataType::DT_INT8, testing::CreateArrayShape({}), CONTENT("\001")) .value(), @@ -473,7 +477,7 @@ INSTANTIATE_TEST_SUITE_P( { "int16", tensorflow::test::AsScalar(1), - testing::CreateArrayContent(v0::DataType::DT_INT16, + testing::CreateArrayContent(federated_language::DataType::DT_INT16, testing::CreateArrayShape({}), CONTENT("\001\000")) .value(), @@ -481,7 +485,7 @@ INSTANTIATE_TEST_SUITE_P( { "int32", tensorflow::test::AsScalar(1), - testing::CreateArrayContent(v0::DataType::DT_INT32, + testing::CreateArrayContent(federated_language::DataType::DT_INT32, testing::CreateArrayShape({}), CONTENT("\001\000\000\000")) .value(), @@ -490,14 +494,15 @@ INSTANTIATE_TEST_SUITE_P( "int64", tensorflow::test::AsScalar(1), testing::CreateArrayContent( - v0::DataType::DT_INT64, testing::CreateArrayShape({}), + federated_language::DataType::DT_INT64, + testing::CreateArrayShape({}), CONTENT("\001\000\000\000\000\000\000\000")) .value(), }, { "uint8", tensorflow::test::AsScalar(1), - testing::CreateArrayContent(v0::DataType::DT_UINT8, + testing::CreateArrayContent(federated_language::DataType::DT_UINT8, testing::CreateArrayShape({}), CONTENT("\001")) .value(), @@ -505,7 +510,7 @@ INSTANTIATE_TEST_SUITE_P( { "uint16", tensorflow::test::AsScalar(1), - testing::CreateArrayContent(v0::DataType::DT_UINT16, + testing::CreateArrayContent(federated_language::DataType::DT_UINT16, testing::CreateArrayShape({}), CONTENT("\001\000")) .value(), @@ -513,7 +518,7 @@ INSTANTIATE_TEST_SUITE_P( { "uint32", tensorflow::test::AsScalar(1), - testing::CreateArrayContent(v0::DataType::DT_UINT32, + testing::CreateArrayContent(federated_language::DataType::DT_UINT32, testing::CreateArrayShape({}), CONTENT("\001\000\000\000")) .value(), @@ -522,14 +527,15 @@ INSTANTIATE_TEST_SUITE_P( "uint64", tensorflow::test::AsScalar(1), testing::CreateArrayContent( - v0::DataType::DT_UINT64, testing::CreateArrayShape({}), + federated_language::DataType::DT_UINT64, + testing::CreateArrayShape({}), CONTENT("\001\000\000\000\000\000\000\000")) .value(), }, { "float16", tensorflow::test::AsScalar(Eigen::half{1.0}), - testing::CreateArrayContent(v0::DataType::DT_HALF, + testing::CreateArrayContent(federated_language::DataType::DT_HALF, testing::CreateArrayShape({}), CONTENT("\000<")) .value(), @@ -537,7 +543,7 @@ INSTANTIATE_TEST_SUITE_P( { "float32", tensorflow::test::AsScalar(1.0), - testing::CreateArrayContent(v0::DataType::DT_FLOAT, + testing::CreateArrayContent(federated_language::DataType::DT_FLOAT, testing::CreateArrayShape({}), CONTENT("\000\000\200?")) .value(), @@ -546,23 +552,26 @@ INSTANTIATE_TEST_SUITE_P( "float64", tensorflow::test::AsScalar(1.0), testing::CreateArrayContent( - v0::DataType::DT_DOUBLE, testing::CreateArrayShape({}), + federated_language::DataType::DT_DOUBLE, + testing::CreateArrayShape({}), CONTENT("\000\000\000\000\000\000\360?")) .value(), }, { "complex64", tensorflow::test::AsScalar(tensorflow::complex64{1.0, 1.0}), - testing::CreateArrayContent(v0::DataType::DT_COMPLEX64, - testing::CreateArrayShape({}), - CONTENT("\000\000\200?\000\000\200?")) + testing::CreateArrayContent( + federated_language::DataType::DT_COMPLEX64, + testing::CreateArrayShape({}), + CONTENT("\000\000\200?\000\000\200?")) .value(), }, { "complex128", tensorflow::test::AsScalar(tensorflow::complex128{1.0, 1.0}), testing::CreateArrayContent( - v0::DataType::DT_COMPLEX128, testing::CreateArrayShape({}), + federated_language::DataType::DT_COMPLEX128, + testing::CreateArrayShape({}), CONTENT("\000\000\000\000\000\000\360?" "\000\000\000\000\000\000\360?")) .value(), @@ -570,9 +579,9 @@ INSTANTIATE_TEST_SUITE_P( { "bfloat16", tensorflow::test::AsScalar(Eigen::bfloat16{1.0}), - testing::CreateArrayContent(v0::DataType::DT_BFLOAT16, - testing::CreateArrayShape({}), - CONTENT("\200?")) + testing::CreateArrayContent( + federated_language::DataType::DT_BFLOAT16, + testing::CreateArrayShape({}), CONTENT("\200?")) .value(), }, { @@ -580,7 +589,7 @@ INSTANTIATE_TEST_SUITE_P( tensorflow::test::AsTensor( {1, 2, 3, 4, 5, 6}, tensorflow::TensorShape({2, 3})), testing::CreateArrayContent( - v0::DataType::DT_INT32, + federated_language::DataType::DT_INT32, testing::CreateArrayShape({2, 3}), CONTENT("\001\000\000\000\002\000\000\000\003\000\000\000\004" "\000\000\000\005\000\000\000\006\000\000\000")) @@ -592,7 +601,7 @@ INSTANTIATE_TEST_SUITE_P( struct TensorFromArrayContentTestCase { std::string test_name; - const v0::Array array_pb; + const federated_language::Array array_pb; const tensorflow::Tensor expected_tensor; }; @@ -615,7 +624,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn({ { "bool", - testing::CreateArrayContent(v0::DataType::DT_BOOL, + testing::CreateArrayContent(federated_language::DataType::DT_BOOL, testing::CreateArrayShape({}), CONTENT("\001")) .value(), @@ -623,7 +632,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "int8", - testing::CreateArrayContent(v0::DataType::DT_INT8, + testing::CreateArrayContent(federated_language::DataType::DT_INT8, testing::CreateArrayShape({}), CONTENT("\001")) .value(), @@ -631,7 +640,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "int16", - testing::CreateArrayContent(v0::DataType::DT_INT16, + testing::CreateArrayContent(federated_language::DataType::DT_INT16, testing::CreateArrayShape({}), CONTENT("\001\000")) .value(), @@ -639,7 +648,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "int32", - testing::CreateArrayContent(v0::DataType::DT_INT32, + testing::CreateArrayContent(federated_language::DataType::DT_INT32, testing::CreateArrayShape({}), CONTENT("\001\000\000\000")) .value(), @@ -648,14 +657,15 @@ INSTANTIATE_TEST_SUITE_P( { "int64", testing::CreateArrayContent( - v0::DataType::DT_INT64, testing::CreateArrayShape({}), + federated_language::DataType::DT_INT64, + testing::CreateArrayShape({}), CONTENT("\001\000\000\000\000\000\000\000")) .value(), tensorflow::test::AsScalar(1), }, { "uint8", - testing::CreateArrayContent(v0::DataType::DT_UINT8, + testing::CreateArrayContent(federated_language::DataType::DT_UINT8, testing::CreateArrayShape({}), CONTENT("\001")) .value(), @@ -663,7 +673,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "uint16", - testing::CreateArrayContent(v0::DataType::DT_UINT16, + testing::CreateArrayContent(federated_language::DataType::DT_UINT16, testing::CreateArrayShape({}), CONTENT("\001\000")) .value(), @@ -671,7 +681,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "uint32", - testing::CreateArrayContent(v0::DataType::DT_UINT32, + testing::CreateArrayContent(federated_language::DataType::DT_UINT32, testing::CreateArrayShape({}), CONTENT("\001\000\000\000")) .value(), @@ -680,14 +690,15 @@ INSTANTIATE_TEST_SUITE_P( { "uint64", testing::CreateArrayContent( - v0::DataType::DT_UINT64, testing::CreateArrayShape({}), + federated_language::DataType::DT_UINT64, + testing::CreateArrayShape({}), CONTENT("\001\000\000\000\000\000\000\000")) .value(), tensorflow::test::AsScalar(1), }, { "float16", - testing::CreateArrayContent(v0::DataType::DT_HALF, + testing::CreateArrayContent(federated_language::DataType::DT_HALF, testing::CreateArrayShape({}), CONTENT("\000<")) .value(), @@ -695,7 +706,7 @@ INSTANTIATE_TEST_SUITE_P( }, { "float32", - testing::CreateArrayContent(v0::DataType::DT_FLOAT, + testing::CreateArrayContent(federated_language::DataType::DT_FLOAT, testing::CreateArrayShape({}), CONTENT("\000\000\200?")) .value(), @@ -704,23 +715,26 @@ INSTANTIATE_TEST_SUITE_P( { "float64", testing::CreateArrayContent( - v0::DataType::DT_DOUBLE, testing::CreateArrayShape({}), + federated_language::DataType::DT_DOUBLE, + testing::CreateArrayShape({}), CONTENT("\000\000\000\000\000\000\360?")) .value(), tensorflow::test::AsScalar(1.0), }, { "complex64", - testing::CreateArrayContent(v0::DataType::DT_COMPLEX64, - testing::CreateArrayShape({}), - CONTENT("\000\000\200?\000\000\200?")) + testing::CreateArrayContent( + federated_language::DataType::DT_COMPLEX64, + testing::CreateArrayShape({}), + CONTENT("\000\000\200?\000\000\200?")) .value(), tensorflow::test::AsScalar(tensorflow::complex64{1.0, 1.0}), }, { "complex128", testing::CreateArrayContent( - v0::DataType::DT_COMPLEX128, testing::CreateArrayShape({}), + federated_language::DataType::DT_COMPLEX128, + testing::CreateArrayShape({}), CONTENT("\000\000\000\000\000\000\360?" "\000\000\000\000\000\000\360?")) .value(), @@ -728,16 +742,16 @@ INSTANTIATE_TEST_SUITE_P( }, { "bfloat16", - testing::CreateArrayContent(v0::DataType::DT_BFLOAT16, - testing::CreateArrayShape({}), - CONTENT("\200?")) + testing::CreateArrayContent( + federated_language::DataType::DT_BFLOAT16, + testing::CreateArrayShape({}), CONTENT("\200?")) .value(), tensorflow::test::AsScalar(Eigen::bfloat16{1.0}), }, { "array", testing::CreateArrayContent( - v0::DataType::DT_INT32, + federated_language::DataType::DT_INT32, testing::CreateArrayShape({2, 3}), CONTENT("\001\000\000\000\002\000\000\000\003\000\000\000\004" "\000\000\000\005\000\000\000\006\000\000\000")) diff --git a/tensorflow_federated/cc/core/impl/executors/type_test_utils.h b/tensorflow_federated/cc/core/impl/executors/type_test_utils.h index 692aaf0777..3dc6694b0a 100644 --- a/tensorflow_federated/cc/core/impl/executors/type_test_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/type_test_utils.h @@ -19,22 +19,23 @@ limitations under the License #include #include "absl/types/span.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" namespace tensorflow_federated { namespace testing { -// Construct a v0::Type of shape > for parameter T. -inline v0::Type NestedStructT(v0::DataType dtype) { - v0::Type float_tensor_type; +// Construct a federated_language::Type of shape > for parameter T. +inline federated_language::Type NestedStructT( + federated_language::DataType dtype) { + federated_language::Type float_tensor_type; float_tensor_type.mutable_tensor()->set_dtype(dtype); - v0::Type nested_struct_type; + federated_language::Type nested_struct_type; for (int i = 0; i < 2; i++) { *nested_struct_type.mutable_struct_()->add_element()->mutable_value() = float_tensor_type; } - v0::Type return_type; + federated_language::Type return_type; *return_type.mutable_struct_()->add_element()->mutable_value() = float_tensor_type; *return_type.mutable_struct_()->add_element()->mutable_value() = @@ -43,9 +44,9 @@ inline v0::Type NestedStructT(v0::DataType dtype) { } // Construct a tensor type with the provided datatype and shape specification. -inline v0::Type TensorT(v0::DataType dtype, - absl::Span shape = {}) { - v0::Type tensor_type; +inline federated_language::Type TensorT(federated_language::DataType dtype, + absl::Span shape = {}) { + federated_language::Type tensor_type; tensor_type.mutable_tensor()->set_dtype(dtype); for (const int64_t& dim : shape) { tensor_type.mutable_tensor()->add_dims(dim); @@ -54,42 +55,51 @@ inline v0::Type TensorT(v0::DataType dtype, } // Construct an unnamed struct type with the provided elements -inline v0::Type StructT(absl::Span elements) { - v0::Type struct_type; - for (const v0::Type& el_type : elements) { +inline federated_language::Type StructT( + absl::Span elements) { + federated_language::Type struct_type; + for (const federated_language::Type& el_type : elements) { *struct_type.mutable_struct_()->add_element()->mutable_value() = el_type; } return struct_type; } -// Construct a functional v0::Type with no argument, and provided return type. -inline v0::Type NoArgFunctionT(v0::Type return_type) { - v0::Type function_type; +// Construct a functional federated_language::Type with no argument, and +// provided return type. +inline federated_language::Type NoArgFunctionT( + federated_language::Type return_type) { + federated_language::Type function_type; *function_type.mutable_function()->mutable_result() = return_type; return function_type; } -// Construct a functional v0::Type with accepting and returning the same type. -inline v0::Type IdentityFunctionT(v0::Type arg_type) { - v0::Type function_type; +// Construct a functional federated_language::Type with accepting and returning +// the same type. +inline federated_language::Type IdentityFunctionT( + federated_language::Type arg_type) { + federated_language::Type function_type; *function_type.mutable_function()->mutable_parameter() = arg_type; *function_type.mutable_function()->mutable_result() = arg_type; return function_type; } -// Construct a functional v0::Type with provided argument and return types. -inline v0::Type FunctionT(v0::Type parameter_type, v0::Type return_type) { - v0::Type function_type; +// Construct a functional federated_language::Type with provided argument and +// return types. +inline federated_language::Type FunctionT( + federated_language::Type parameter_type, + federated_language::Type return_type) { + federated_language::Type function_type; *function_type.mutable_function()->mutable_parameter() = parameter_type; *function_type.mutable_function()->mutable_result() = return_type; return function_type; } -// Construct a v0::Type of shape for parameter T, with num_reps -// elements. -inline v0::Type FlatStructT(v0::DataType dtype, int num_reps) { - v0::Type float_tensor_type; +// Construct a federated_language::Type of shape for parameter T, with +// num_reps elements. +inline federated_language::Type FlatStructT(federated_language::DataType dtype, + int num_reps) { + federated_language::Type float_tensor_type; float_tensor_type.mutable_tensor()->set_dtype(dtype); - v0::Type struct_type; + federated_language::Type struct_type; for (int i = 0; i < num_reps; i++) { *struct_type.mutable_struct_()->add_element()->mutable_value() = float_tensor_type; diff --git a/tensorflow_federated/cc/core/impl/executors/type_utils.cc b/tensorflow_federated/cc/core/impl/executors/type_utils.cc index eef50323b5..38f8159640 100644 --- a/tensorflow_federated/cc/core/impl/executors/type_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/type_utils.cc @@ -19,21 +19,23 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { -absl::StatusOr InferTypeFromValue(const v0::Value& value_pb) { - v0::Type value_type_pb; +absl::StatusOr InferTypeFromValue( + const v0::Value& value_pb) { + federated_language::Type value_type_pb; switch (value_pb.value_case()) { case v0::Value::kArray: { - v0::TensorType* tensor_type_pb = value_type_pb.mutable_tensor(); + federated_language::TensorType* tensor_type_pb = + value_type_pb.mutable_tensor(); tensor_type_pb->set_dtype(value_pb.array().dtype()); tensor_type_pb->mutable_dims()->Assign( value_pb.array().shape().dim().begin(), @@ -41,7 +43,8 @@ absl::StatusOr InferTypeFromValue(const v0::Value& value_pb) { break; } case v0::Value::kStruct: { - v0::StructType* struct_type = value_type_pb.mutable_struct_(); + federated_language::StructType* struct_type = + value_type_pb.mutable_struct_(); for (const v0::Value::Struct::Element& element_pb : value_pb.struct_().element()) { *struct_type->add_element()->mutable_value() = @@ -58,7 +61,8 @@ absl::StatusOr InferTypeFromValue(const v0::Value& value_pb) { break; } case v0::Value::kSequence: { - v0::SequenceType* sequence_type = value_type_pb.mutable_sequence(); + federated_language::SequenceType* sequence_type = + value_type_pb.mutable_sequence(); *sequence_type->mutable_element() = value_pb.sequence().element_type(); break; } diff --git a/tensorflow_federated/cc/core/impl/executors/type_utils.h b/tensorflow_federated/cc/core/impl/executors/type_utils.h index 0d29313f58..a884c0dc9a 100644 --- a/tensorflow_federated/cc/core/impl/executors/type_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/type_utils.h @@ -17,12 +17,13 @@ limitations under the License #define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_TYPE_UTILS_H_ #include "absl/status/statusor.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { -absl::StatusOr InferTypeFromValue(const v0::Value& value_pb); +absl::StatusOr InferTypeFromValue( + const v0::Value& value_pb); } // namespace tensorflow_federated diff --git a/tensorflow_federated/cc/core/impl/executors/value_test_utils.h b/tensorflow_federated/cc/core/impl/executors/value_test_utils.h index 147c7448c7..35fcc12531 100644 --- a/tensorflow_federated/cc/core/impl/executors/value_test_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/value_test_utils.h @@ -34,6 +34,9 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -44,9 +47,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -54,9 +54,10 @@ namespace testing { inline v0::Value IntrinsicV( std::string_view uri, - std::optional type_spec = std::nullopt) { + std::optional type_spec = std::nullopt) { v0::Value value_proto; - v0::Computation* computation_pb = value_proto.mutable_computation(); + federated_language::Computation* computation_pb = + value_proto.mutable_computation(); // Construct an explicit string from this string-view; this silent conversion // is not present in OSS. *computation_pb->mutable_intrinsic()->mutable_uri() = std::string(uri); @@ -69,7 +70,7 @@ inline v0::Value IntrinsicV( // NOTE: Returns a value whose federated type `.member` field is unset. inline v0::Value ServerV(v0::Value server_val) { v0::Value value_proto; - v0::FederatedType* type_proto = + federated_language::FederatedType* type_proto = value_proto.mutable_federated()->mutable_type(); type_proto->set_all_equal(true); *type_proto->mutable_placement()->mutable_value()->mutable_uri() = kServerUri; @@ -81,7 +82,7 @@ inline v0::Value ServerV(v0::Value server_val) { inline v0::Value ClientsV(const absl::Span client_values, bool all_equal = false) { v0::Value value_proto; - v0::FederatedType* type_proto = + federated_language::FederatedType* type_proto = value_proto.mutable_federated()->mutable_type(); type_proto->set_all_equal(all_equal); *type_proto->mutable_placement()->mutable_value()->mutable_uri() = @@ -130,15 +131,15 @@ inline v0::Value SequenceV(int64_t start, int64_t stop, int64_t step) { for (int i = start; i < stop; i += step) { v0::Value::Sequence::Element* element_pb = sequence_pb->add_element(); - v0::Array* array_pb = element_pb->add_flat_value(); - array_pb->set_dtype(v0::DT_INT64); + federated_language::Array* array_pb = element_pb->add_flat_value(); + array_pb->set_dtype(federated_language::DT_INT64); array_pb->mutable_shape()->mutable_dim()->Clear(); array_pb->mutable_int64_list()->add_value(i); } - v0::TensorType* tensor_type_pb = + federated_language::TensorType* tensor_type_pb = sequence_pb->mutable_element_type()->mutable_tensor(); - tensor_type_pb->set_dtype(v0::DataType::DT_INT64); + tensor_type_pb->set_dtype(federated_language::DataType::DT_INT64); tensor_type_pb->add_dims(1); return value_pb; @@ -152,30 +153,31 @@ inline v0::Value SequenceV(std::vector> elements) { for (const std::vector& flat_values : elements) { v0::Value::Sequence::Element* element_pb = sequence_pb->add_element(); for (const int64_t value : flat_values) { - v0::Array* array_pb = element_pb->add_flat_value(); - array_pb->set_dtype(v0::DT_INT64); + federated_language::Array* array_pb = element_pb->add_flat_value(); + array_pb->set_dtype(federated_language::DT_INT64); array_pb->mutable_shape()->mutable_dim()->Clear(); array_pb->mutable_int64_list()->add_value(value); } } - v0::StructType* struct_type_pb = + federated_language::StructType* struct_type_pb = sequence_pb->mutable_element_type()->mutable_struct_(); for (int i = 0; i < elements[0].size(); i++) { - v0::StructType::Element* element_pb = struct_type_pb->add_element(); - v0::TensorType* tensor_type_pb = + federated_language::StructType::Element* element_pb = + struct_type_pb->add_element(); + federated_language::TensorType* tensor_type_pb = element_pb->mutable_value()->mutable_tensor(); - tensor_type_pb->set_dtype(v0::DataType::DT_INT64); + tensor_type_pb->set_dtype(federated_language::DataType::DT_INT64); tensor_type_pb->add_dims(1); } return value_pb; } -inline v0::Type MakeInt64ScalarType() { - v0::Type type; - v0::TensorType* tensor_type = type.mutable_tensor(); - tensor_type->set_dtype(v0::DataType::DT_INT64); +inline federated_language::Type MakeInt64ScalarType() { + federated_language::Type type; + federated_language::TensorType* tensor_type = type.mutable_tensor(); + tensor_type->set_dtype(federated_language::DataType::DT_INT64); tensor_type->add_dims(1); return type; } @@ -224,10 +226,11 @@ MATCHER(TensorsProtoEqual, namespace intrinsic { -#define INTRINSIC_FUNC(name, uri) \ - inline v0::Value name(std::optional type_spec = \ - std::nullopt) { \ - return IntrinsicV(#uri, type_spec); \ +#define INTRINSIC_FUNC(name, uri) \ + inline v0::Value name( \ + std::optional type_spec = \ + std::nullopt) { \ + return IntrinsicV(#uri, type_spec); \ } INTRINSIC_FUNC(ArgsIntoSequenceV, args_into_sequence); @@ -247,37 +250,38 @@ INTRINSIC_FUNC(FederatedZipAtServerV, federated_zip_at_server); } // namespace intrinsic -inline v0::Value ComputationV(v0::Computation computation_pb) { +inline v0::Value ComputationV(federated_language::Computation computation_pb) { v0::Value value_pb; *value_pb.mutable_computation() = computation_pb; return value_pb; } -inline v0::Computation SelectionComputation(v0::Computation source_pb, - int32_t index) { - v0::Computation computation_pb; - v0::Selection* selection_pb = computation_pb.mutable_selection(); +inline federated_language::Computation SelectionComputation( + federated_language::Computation source_pb, int32_t index) { + federated_language::Computation computation_pb; + federated_language::Selection* selection_pb = + computation_pb.mutable_selection(); *selection_pb->mutable_source() = source_pb; selection_pb->set_index(index); return computation_pb; } -inline v0::Computation StructComputation( - std::vector elements) { - v0::Computation computation_pb; - v0::Struct* struct_pb = computation_pb.mutable_struct_(); +inline federated_language::Computation StructComputation( + std::vector elements) { + federated_language::Computation computation_pb; + federated_language::Struct* struct_pb = computation_pb.mutable_struct_(); for (const auto& element : elements) { - v0::Struct::Element* element_pb = struct_pb->add_element(); + federated_language::Struct::Element* element_pb = struct_pb->add_element(); *element_pb->mutable_value() = element; } return computation_pb; } -inline v0::Computation LambdaComputation( +inline federated_language::Computation LambdaComputation( std::optional parameter_name, - v0::Computation result_computation_value) { - v0::Computation computation_pb; - v0::Lambda* lambda_pb = computation_pb.mutable_lambda(); + federated_language::Computation result_computation_value) { + federated_language::Computation computation_pb; + federated_language::Lambda* lambda_pb = computation_pb.mutable_lambda(); if (parameter_name != std::nullopt) { lambda_pb->mutable_parameter_name()->assign(parameter_name.value().data(), parameter_name.value().size()); @@ -286,13 +290,14 @@ inline v0::Computation LambdaComputation( return computation_pb; } -inline v0::Computation BlockComputation( - std::vector> locals, - v0::Computation result) { - v0::Computation computation_pb; - v0::Block* block_pb = computation_pb.mutable_block(); +inline federated_language::Computation BlockComputation( + std::vector> + locals, + federated_language::Computation result) { + federated_language::Computation computation_pb; + federated_language::Block* block_pb = computation_pb.mutable_block(); for (const auto& local : locals) { - v0::Block::Local* new_local_pb = block_pb->add_local(); + federated_language::Block::Local* new_local_pb = block_pb->add_local(); *new_local_pb->mutable_name() = std::get<0>(local); *new_local_pb->mutable_value() = std::get<1>(local); } @@ -300,35 +305,39 @@ inline v0::Computation BlockComputation( return computation_pb; } -inline v0::Computation ReferenceComputation(std::string_view reference_name) { - v0::Computation computation_pb; +inline federated_language::Computation ReferenceComputation( + std::string_view reference_name) { + federated_language::Computation computation_pb; computation_pb.mutable_reference()->mutable_name()->assign( reference_name.data(), reference_name.size()); return computation_pb; } -inline v0::Computation IntrinsicComputation(std::string_view uri) { - v0::Computation computation_pb; +inline federated_language::Computation IntrinsicComputation( + std::string_view uri) { + federated_language::Computation computation_pb; computation_pb.mutable_intrinsic()->mutable_uri()->assign(uri.data(), uri.size()); return computation_pb; } -inline v0::Computation DataComputation(std::string_view uri) { - v0::Computation computation_pb; +inline federated_language::Computation DataComputation(std::string_view uri) { + federated_language::Computation computation_pb; computation_pb.mutable_data()->mutable_uri()->assign(uri.data(), uri.size()); return computation_pb; } -inline v0::Computation PlacementComputation(std::string_view uri) { - v0::Computation computation_pb; +inline federated_language::Computation PlacementComputation( + std::string_view uri) { + federated_language::Computation computation_pb; computation_pb.mutable_placement()->mutable_uri()->assign(uri.data(), uri.size()); return computation_pb; } -inline v0::Computation LiteralComputation(v0::Array array_pb) { - v0::Computation computation_pb; +inline federated_language::Computation LiteralComputation( + federated_language::Array array_pb) { + federated_language::Computation computation_pb; computation_pb.mutable_literal()->mutable_value()->Swap(&array_pb); return computation_pb; } diff --git a/tensorflow_federated/cc/core/impl/executors/value_validation.cc b/tensorflow_federated/cc/core/impl/executors/value_validation.cc index b0e7e98e3f..f1300742fb 100644 --- a/tensorflow_federated/cc/core/impl/executors/value_validation.cc +++ b/tensorflow_federated/cc/core/impl/executors/value_validation.cc @@ -21,8 +21,8 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/cardinalities.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { diff --git a/tensorflow_federated/cc/core/impl/executors/value_validation_test.cc b/tensorflow_federated/cc/core/impl/executors/value_validation_test.cc index d1731691b1..b2b8ae3f19 100644 --- a/tensorflow_federated/cc/core/impl/executors/value_validation_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/value_validation_test.cc @@ -22,9 +22,9 @@ limitations under the License #include "googletest/include/gtest/gtest.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" #include "tensorflow_federated/proto/v0/executor.pb.h" namespace tensorflow_federated { @@ -74,7 +74,7 @@ TEST_F(ValueValidationTest, ValidateFederatedErrorOnWrongNumberClients) { TEST_F(ValueValidationTest, ValidateFederatedErrorOnNonAllEqualServer) { v0::Value value_proto; - v0::FederatedType* type_proto = + federated_language::FederatedType* type_proto = value_proto.mutable_federated()->mutable_type(); type_proto->set_all_equal(false); *type_proto->mutable_placement()->mutable_value()->mutable_uri() = "server"; diff --git a/tensorflow_federated/cc/core/impl/executors/xla_executor.cc b/tensorflow_federated/cc/core/impl/executors/xla_executor.cc index 757c8e8d3b..dccc9a46c6 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_executor.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_executor.cc @@ -31,6 +31,7 @@ limitations under the License #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "federated_language/proto/computation.pb.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/client/client.h" @@ -53,7 +54,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" #include "tensorflow_federated/cc/core/impl/executors/threading.h" #include "tensorflow_federated/cc/core/impl/executors/xla_utils.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" // clang-format off // In TF 2.17 MultiPlatformManager was renamed to PlatformManager. Remove @@ -99,17 +99,20 @@ class ServiceTensor { class Computation { public: Computation(xla::ExecutionHandle&& compiled_computation, - v0::Xla::Binding arg_binding, v0::Xla::Binding result_binding, - v0::Type computation_type) + federated_language::Xla::Binding arg_binding, + federated_language::Xla::Binding result_binding, + federated_language::Type computation_type) : xla_computation_(std::move(compiled_computation)), arg_binding_(std::move(arg_binding)), result_binding_(std::move(result_binding)), computation_type_(std::move(computation_type)) {} const xla::ExecutionHandle& xla_computation() { return xla_computation_; } - const v0::Xla::Binding& arg_binding() { return arg_binding_; } - const v0::Xla::Binding& result_binding() { return result_binding_; } - const v0::Type& type() { return computation_type_; } + const federated_language::Xla::Binding& arg_binding() { return arg_binding_; } + const federated_language::Xla::Binding& result_binding() { + return result_binding_; + } + const federated_language::Type& type() { return computation_type_; } private: Computation() = delete; @@ -120,11 +123,12 @@ class Computation { // TFF-JAX Python API to support unknown shapes and ranks in parameter // tensors, this assumption will need to be relaxed for these cases. // One option might involve preserving the proto and recompiling on the - // fly, adding to an internal cache of v0::Type to xla::ExecutionHandles. + // fly, adding to an internal cache of federated_language::Type to + // xla::ExecutionHandles. const xla::ExecutionHandle xla_computation_; - const v0::Xla::Binding arg_binding_; - const v0::Xla::Binding result_binding_; - const v0::Type computation_type_; + const federated_language::Xla::Binding arg_binding_; + const federated_language::Xla::Binding result_binding_; + const federated_language::Type computation_type_; }; // Representation for values embedded in the XLA executor. Generally, this class @@ -186,10 +190,11 @@ class XLAExecutorValue { template ::value_type> absl::Status PopulateFlatVectorLikeBinding( - const v0::Type& type, const v0::Xla::Binding& binding, F processing_fn, + const federated_language::Type& type, + const federated_language::Xla::Binding& binding, F processing_fn, std::vector* vector_to_populate) { switch (type.type_case()) { - case v0::Type::kTensor: { + case federated_language::Type::kTensor: { if (!binding.has_tensor()) { return absl::InvalidArgumentError( "Mismatch between tensor type and non-tensor binding while " @@ -201,7 +206,7 @@ absl::Status PopulateFlatVectorLikeBinding( TFF_TRY(processing_fn(type.tensor())); return absl::OkStatus(); } - case v0::Type::kStruct: { + case federated_language::Type::kStruct: { if (!binding.has_struct_()) { return absl::InvalidArgumentError( "Mismatch between struct type and non-struct binding while " @@ -240,21 +245,21 @@ absl::Status PopulateFlatVectorLikeBinding( // Returns a vector of TFF TensorTypes which correspond to the vector of tensors // specified by the binding argument. -absl::Status FlattenTypeToTensors(const v0::Type& type, - const v0::Xla::Binding& binding, - std::vector* tensor_vector) { - auto identity = - [](const v0::TensorType& x) -> absl::StatusOr { - return x; - }; +absl::Status FlattenTypeToTensors( + const federated_language::Type& type, + const federated_language::Xla::Binding& binding, + std::vector* tensor_vector) { + auto identity = [](const federated_language::TensorType& x) + -> absl::StatusOr { return x; }; return PopulateFlatVectorLikeBinding(type, binding, identity, tensor_vector); } // Computes vector of xla::Shape pointers from the type argument, in flattened // order determined by the binding argument. Populated in flat_shapes. -absl::Status ComputeFlatShapesFromType(const v0::Type& type, - const v0::Xla::Binding& binding, - std::vector* flat_shapes) { +absl::Status ComputeFlatShapesFromType( + const federated_language::Type& type, + const federated_language::Xla::Binding& binding, + std::vector* flat_shapes) { return PopulateFlatVectorLikeBinding(type, binding, ShapeFromTensorType, flat_shapes); } @@ -262,19 +267,21 @@ absl::Status ComputeFlatShapesFromType(const v0::Type& type, // Computes the number of tensor elements in a given binding. We interpret an // unset binding to contain 0 elements, for uniformity of handling unset // parameter bindings. -int ComputeNumElementsFromBinding(const v0::Xla::Binding& binding) { +int ComputeNumElementsFromBinding( + const federated_language::Xla::Binding& binding) { switch (binding.binding_case()) { - case v0::Xla::Binding::kTensor: { + case federated_language::Xla::Binding::kTensor: { return 1; } - case v0::Xla::Binding::kStruct: { + case federated_language::Xla::Binding::kStruct: { int num_elements = 0; - for (const v0::Xla::Binding& el_binding : binding.struct_().element()) { + for (const federated_language::Xla::Binding& el_binding : + binding.struct_().element()) { num_elements += ComputeNumElementsFromBinding(el_binding); } return num_elements; } - case v0::Xla::Binding::BINDING_NOT_SET: { + case federated_language::Xla::Binding::BINDING_NOT_SET: { return 0; } } @@ -286,10 +293,10 @@ int ComputeNumElementsFromBinding(const v0::Xla::Binding& binding) { // indices present in the binding argument can be assigned directly to their // appropriate locations. absl::Status FlattenValuesIntoBinding( - const v0::Xla::Binding& binding, const XLAExecutorValue& value, - std::vector& flat_vector) { + const federated_language::Xla::Binding& binding, + const XLAExecutorValue& value, std::vector& flat_vector) { switch (binding.binding_case()) { - case v0::Xla::Binding::kTensor: { + case federated_language::Xla::Binding::kTensor: { int32_t tensor_index_in_vector = binding.tensor().index(); if (value.type() != XLAExecutorValue::ValueType::TENSOR) { return absl::InvalidArgumentError(absl::StrCat( @@ -304,7 +311,7 @@ absl::Status FlattenValuesIntoBinding( flat_vector[tensor_index_in_vector] = tensor_data; return absl::OkStatus(); } - case v0::Xla::Binding::kStruct: { + case federated_language::Xla::Binding::kStruct: { if (value.type() != XLAExecutorValue::ValueType::STRUCT) { return absl::InvalidArgumentError(absl::StrCat( "Error encountered in FlattenValuesIntoBinding; encountered struct " @@ -337,17 +344,18 @@ absl::Status FlattenValuesIntoBinding( // binding argument. This function is conceptually the inverse of the above. absl::StatusOr PackageFlatValuesAsBinding( const std::vector& flat_tensor_values, - const v0::Xla::Binding& binding) { + const federated_language::Xla::Binding& binding) { switch (binding.binding_case()) { - case v0::Xla::Binding::kTensor: { + case federated_language::Xla::Binding::kTensor: { // Simply return the (tensor) XLAExecutorValue at the index indicated by // the binding. return flat_tensor_values[binding.tensor().index()]; } - case v0::Xla::Binding::kStruct: { + case federated_language::Xla::Binding::kStruct: { std::vector struct_element_values; struct_element_values.reserve(binding.struct_().element_size()); - for (const v0::Xla::Binding& el_binding : binding.struct_().element()) { + for (const federated_language::Xla::Binding& el_binding : + binding.struct_().element()) { struct_element_values.emplace_back(TFF_TRY( PackageFlatValuesAsBinding(flat_tensor_values, el_binding))); } @@ -472,9 +480,9 @@ class XLAExecutor : public ExecutorBase { } absl::StatusOr CreateValueComputation( - const v0::Computation& comp_pb) { + const federated_language::Computation& comp_pb) { switch (comp_pb.computation_case()) { - case v0::Computation::kXla: { + case federated_language::Computation::kXla: { if (!comp_pb.type().has_function()) { return absl::InvalidArgumentError( absl::StrCat("Computation proto with non-functional type " @@ -486,7 +494,8 @@ class XLAExecutor : public ExecutorBase { xla::XlaComputation xla_comp(std::move(hlo_proto)); // Compute the vector of flat arg shapes; these will be needed to // compile the computation. - v0::Xla::Binding arg_binding = comp_pb.xla().parameter(); + federated_language::Xla::Binding arg_binding = + comp_pb.xla().parameter(); int num_arg_elements = ComputeNumElementsFromBinding(arg_binding); // Preallocate this vector to num_arg_elements, so that we can // assign to these elements directly in the function call below. @@ -506,12 +515,13 @@ class XLAExecutor : public ExecutorBase { } // Finally, construct the representation of this computation in the // XLA executor. - v0::Xla::Binding result_binding = comp_pb.xla().result(); + federated_language::Xla::Binding result_binding = + comp_pb.xla().result(); return XLAExecutorValue(std::make_shared( std::move(*computation_handle), arg_binding, result_binding, comp_pb.type())); } - case v0::Computation::kLiteral: { + case federated_language::Computation::kLiteral: { absl::StatusOr> data = xla_client_->TransferToServer( TFF_TRY(LiteralFromArray(comp_pb.literal().value()))); @@ -625,9 +635,10 @@ class XLAExecutor : public ExecutorBase { absl::StrCat("Error calling XLA computation. Message: ", result.status().message())); } - const v0::Xla::Binding& result_binding = fn->result_binding(); + const federated_language::Xla::Binding& result_binding = + fn->result_binding(); switch (result_binding.binding_case()) { - case v0::Xla::Binding::kTensor: { + case federated_language::Xla::Binding::kTensor: { // JAX tracing always compiles results to be tuples, which would // result in length 1 tuples. absl::StatusOr>> @@ -648,7 +659,7 @@ class XLAExecutor : public ExecutorBase { TFF_TRY(PrimitiveTypeFromDataType( fn->type().function().result().tensor().dtype()))); } - case v0::Xla::Binding::kStruct: { + case federated_language::Xla::Binding::kStruct: { const int num_result_elements = ComputeNumElementsFromBinding(result_binding); std::vector> global_data_vector; @@ -670,7 +681,8 @@ class XLAExecutor : public ExecutorBase { std::vector flat_value_vector; // Preallocate the flat types tensor as required to assign directly to // its elements. - std::vector flat_tensor_types(num_result_elements); + std::vector flat_tensor_types( + num_result_elements); TFF_TRY(FlattenTypeToTensors(fn->type().function().result(), result_binding, &flat_tensor_types)); flat_value_vector.reserve(flat_tensor_types.size()); diff --git a/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc b/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc index 7150bad17a..b3ec3cfd25 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc @@ -30,6 +30,9 @@ limitations under the License #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -50,9 +53,6 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/value_test_utils.h" #include "tensorflow_federated/cc/testing/protobuf_matchers.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" ABSL_FLAG(std::string, tff_xla_executor_test_platform, "Host", "The name of the XLA platform to run the tests on. By default will " @@ -73,22 +73,22 @@ using ::tensorflow_federated::testing::TensorT; using ::tensorflow_federated::testing::TensorV; using ::testing::HasSubstr; -inline absl::StatusOr> BindingFromType( - v0::Type type, int next_unused_index) { +inline absl::StatusOr> +BindingFromType(federated_language::Type type, int next_unused_index) { switch (type.type_case()) { - case v0::Type::kTensor: { - v0::Xla::Binding binding; + case federated_language::Type::kTensor: { + federated_language::Xla::Binding binding; binding.mutable_tensor()->set_index(next_unused_index); return std::make_tuple(binding, next_unused_index + 1); } - case v0::Type::kStruct: { - v0::Xla::Binding binding; + case federated_language::Type::kStruct: { + federated_language::Xla::Binding binding; for (const auto& type_element : type.struct_().element()) { auto partial_binding = TFF_TRY(BindingFromType(type_element.value(), next_unused_index)); next_unused_index = std::get(partial_binding); *binding.mutable_struct_()->add_element() = - std::get(partial_binding); + std::get(partial_binding); } return std::make_tuple(binding, next_unused_index); } @@ -99,12 +99,12 @@ inline absl::StatusOr> BindingFromType( } } -inline v0::Value ComputationV(std::optional in_binding, - v0::Xla::Binding out_binding, - xla::XlaComputation xla_comp, - v0::Type computation_type) { +inline v0::Value ComputationV( + std::optional in_binding, + federated_language::Xla::Binding out_binding, xla::XlaComputation xla_comp, + federated_language::Type computation_type) { v0::Value value_pb; - v0::Computation* comp_pb = value_pb.mutable_computation(); + federated_language::Computation* comp_pb = value_pb.mutable_computation(); comp_pb->mutable_xla()->mutable_hlo_module()->PackFrom(xla_comp.proto()); *comp_pb->mutable_type() = computation_type; if (in_binding.has_value()) { @@ -317,8 +317,9 @@ TEST_F(XLAExecutorTest, CreateValueComputationTensorNonFunctionalTypeFails) { xla::Parameter(&builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x"); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type float_tensor_type; - float_tensor_type.mutable_tensor()->set_dtype(v0::DataType::DT_FLOAT); + federated_language::Type float_tensor_type; + float_tensor_type.mutable_tensor()->set_dtype( + federated_language::DataType::DT_FLOAT); auto tensor_binding = std::get<0>(TFF_ASSERT_OK(BindingFromType(float_tensor_type, 0))); @@ -344,13 +345,14 @@ TEST_F(XLAExecutorTest, xla::Parameter(&builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x"); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type float_tensor; - float_tensor.mutable_tensor()->set_dtype(v0::DataType::DT_FLOAT); - v0::Type function_type; + federated_language::Type float_tensor; + float_tensor.mutable_tensor()->set_dtype( + federated_language::DataType::DT_FLOAT); + federated_language::Type function_type; *function_type.mutable_function()->mutable_result() = float_tensor; *function_type.mutable_function()->mutable_parameter() = float_tensor; // We create a binding with mismatched structure. - v0::Type struct_type; + federated_language::Type struct_type; *struct_type.mutable_struct_()->add_element()->mutable_value() = float_tensor; auto struct_binding = @@ -378,12 +380,13 @@ TEST_F(XLAExecutorTest, XLAShapeWithUnknownDims(tensorflow::DT_FLOAT, num_dims), "x"); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type float_unk_shape_tensor; - float_unk_shape_tensor.mutable_tensor()->set_dtype(v0::DataType::DT_FLOAT); + federated_language::Type float_unk_shape_tensor; + float_unk_shape_tensor.mutable_tensor()->set_dtype( + federated_language::DataType::DT_FLOAT); for (int i = 0; i < num_dims; i++) { float_unk_shape_tensor.mutable_tensor()->add_dims(-1); } - v0::Type function_type; + federated_language::Type function_type; *function_type.mutable_function()->mutable_result() = float_unk_shape_tensor; *function_type.mutable_function()->mutable_parameter() = float_unk_shape_tensor; @@ -404,10 +407,11 @@ TEST_F(XLAExecutorTest, CreateValueComputationTensorParameterUnknownRankFails) { "x"); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type float_unk_rank_tensor; - float_unk_rank_tensor.mutable_tensor()->set_dtype(v0::DataType::DT_FLOAT); + federated_language::Type float_unk_rank_tensor; + float_unk_rank_tensor.mutable_tensor()->set_dtype( + federated_language::DataType::DT_FLOAT); float_unk_rank_tensor.mutable_tensor()->set_unknown_rank(true); - v0::Type function_type; + federated_language::Type function_type; *function_type.mutable_function()->mutable_result() = float_unk_rank_tensor; *function_type.mutable_function()->mutable_parameter() = float_unk_rank_tensor; @@ -423,12 +427,14 @@ TEST_F(XLAExecutorTest, CreateValueComputationTensorParameterUnknownRankFails) { } TEST_F(XLAExecutorTest, CreateValueComputationLiteralReturnsResult) { - const v0::DataType dtype = v0::DataType::DT_INT32; - v0::ArrayShape shape_pb = testing::CreateArrayShape({}); + const federated_language::DataType dtype = + federated_language::DataType::DT_INT32; + federated_language::ArrayShape shape_pb = testing::CreateArrayShape({}); auto values = {1}; - v0::Array array_pb = + federated_language::Array array_pb = TFF_ASSERT_OK(testing::CreateArray(dtype, shape_pb, values)); - v0::Computation computation_pb = testing::LiteralComputation(array_pb); + federated_language::Computation computation_pb = + testing::LiteralComputation(array_pb); v0::Value value_pb = testing::ComputationV(computation_pb); const OwnedValueId& embedded_fn = @@ -447,8 +453,8 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallSingleTensor) { xla::Tuple(&builder, {constant}); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - auto tensor_type = TensorT(v0::DataType::DT_FLOAT); - v0::Type function_type = NoArgFunctionT(tensor_type); + auto tensor_type = TensorT(federated_language::DataType::DT_FLOAT); + federated_language::Type function_type = NoArgFunctionT(tensor_type); v0::Value computation = ComputationV( std::nullopt, std::get<0>(TFF_ASSERT_OK(BindingFromType(tensor_type, 0))), std::move(*xla_computation), function_type); @@ -471,8 +477,9 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallTensorStructure) { absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type return_type = FlatStructT(v0::DataType::DT_FLOAT, 2); - v0::Type function_type = NoArgFunctionT(return_type); + federated_language::Type return_type = + FlatStructT(federated_language::DataType::DT_FLOAT, 2); + federated_language::Type function_type = NoArgFunctionT(return_type); v0::Value computation = ComputationV( std::nullopt, std::get<0>(TFF_ASSERT_OK(BindingFromType(return_type, 0))), @@ -497,8 +504,9 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallNestedTensorStructure) { ASSERT_TRUE(xla_computation.ok()); // We construct a return type > - v0::Type nested_struct_type = NestedStructT(v0::DataType::DT_FLOAT); - v0::Type function_type = NoArgFunctionT(nested_struct_type); + federated_language::Type nested_struct_type = + NestedStructT(federated_language::DataType::DT_FLOAT); + federated_language::Type function_type = NoArgFunctionT(nested_struct_type); v0::Value computation = ComputationV( std::nullopt, @@ -525,8 +533,9 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityScalar) { xla::Tuple(&builder, {parameter}); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type float_tensor_type = TensorT(v0::DataType::DT_FLOAT); - v0::Type function_type = IdentityFunctionT(float_tensor_type); + federated_language::Type float_tensor_type = + TensorT(federated_language::DataType::DT_FLOAT); + federated_language::Type function_type = IdentityFunctionT(float_tensor_type); auto binding = std::get<0>(TFF_ASSERT_OK(BindingFromType(float_tensor_type, 0))); v0::Value computation = ComputationV( @@ -553,9 +562,10 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeIdentitySingletonStruct) { xla::Tuple(&builder, {parameter}); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type single_float_struct_type = - StructT({TensorT(v0::DataType::DT_FLOAT)}); - v0::Type function_type = IdentityFunctionT(single_float_struct_type); + federated_language::Type single_float_struct_type = + StructT({TensorT(federated_language::DataType::DT_FLOAT)}); + federated_language::Type function_type = + IdentityFunctionT(single_float_struct_type); auto binding = std::get<0>(TFF_ASSERT_OK(BindingFromType(single_float_struct_type, 0))); v0::Value computation = ComputationV( @@ -585,8 +595,10 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityNestedStruct) { absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type nested_struct_type = NestedStructT(v0::DataType::DT_FLOAT); - v0::Type function_type = IdentityFunctionT(nested_struct_type); + federated_language::Type nested_struct_type = + NestedStructT(federated_language::DataType::DT_FLOAT); + federated_language::Type function_type = + IdentityFunctionT(nested_struct_type); auto binding = std::get<0>(TFF_ASSERT_OK(BindingFromType(nested_struct_type, 0))); v0::Value computation = ComputationV( @@ -618,10 +630,12 @@ TEST_F(XLAExecutorTest, CallAndMaterializeIdentityPartiallyNonScalarStruct) { ASSERT_TRUE(xla_computation.ok()); // Create a computation type to match the above. - v0::Type scalar = TensorT(v0::DataType::DT_FLOAT); - v0::Type matrix = TensorT(v0::DataType::DT_FLOAT, {10, 10}); - v0::Type struct_type = StructT({scalar, matrix}); - v0::Type function_type = IdentityFunctionT(struct_type); + federated_language::Type scalar = + TensorT(federated_language::DataType::DT_FLOAT); + federated_language::Type matrix = + TensorT(federated_language::DataType::DT_FLOAT, {10, 10}); + federated_language::Type struct_type = StructT({scalar, matrix}); + federated_language::Type function_type = IdentityFunctionT(struct_type); auto binding = std::get<0>(TFF_ASSERT_OK(BindingFromType(struct_type, 0))); v0::Value computation = ComputationV( binding, binding, std::move(*xla_computation), function_type); @@ -650,9 +664,12 @@ TEST_F(XLAExecutorTest, xla::Tuple(&builder, {x, xla::Add(y, z)}); absl::StatusOr xla_computation = builder.Build(); ASSERT_TRUE(xla_computation.ok()); - v0::Type nested_struct_type = NestedStructT(v0::DataType::DT_FLOAT); - v0::Type result_type = FlatStructT(v0::DataType::DT_FLOAT, 2); - v0::Type function_type = FunctionT(nested_struct_type, result_type); + federated_language::Type nested_struct_type = + NestedStructT(federated_language::DataType::DT_FLOAT); + federated_language::Type result_type = + FlatStructT(federated_language::DataType::DT_FLOAT, 2); + federated_language::Type function_type = + FunctionT(nested_struct_type, result_type); auto parameter_binding = std::get<0>(TFF_ASSERT_OK(BindingFromType(nested_struct_type, 0))); auto result_binding = diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils.cc b/tensorflow_federated/cc/core/impl/executors/xla_utils.cc index 06da07e101..51d3c3df0d 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils.cc @@ -22,50 +22,50 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" namespace tensorflow_federated { absl::StatusOr PrimitiveTypeFromDataType( - const v0::DataType data_type) { + const federated_language::DataType data_type) { switch (data_type) { - case v0::DataType::DT_BOOL: + case federated_language::DataType::DT_BOOL: return xla::PRED; - case v0::DataType::DT_INT8: + case federated_language::DataType::DT_INT8: return xla::S8; - case v0::DataType::DT_INT16: + case federated_language::DataType::DT_INT16: return xla::S16; - case v0::DataType::DT_INT32: + case federated_language::DataType::DT_INT32: return xla::S32; - case v0::DataType::DT_INT64: + case federated_language::DataType::DT_INT64: return xla::S64; - case v0::DataType::DT_UINT8: + case federated_language::DataType::DT_UINT8: return xla::U8; - case v0::DataType::DT_UINT16: + case federated_language::DataType::DT_UINT16: return xla::U16; - case v0::DataType::DT_UINT32: + case federated_language::DataType::DT_UINT32: return xla::U32; - case v0::DataType::DT_UINT64: + case federated_language::DataType::DT_UINT64: return xla::U64; - case v0::DataType::DT_HALF: + case federated_language::DataType::DT_HALF: return xla::F16; - case v0::DataType::DT_FLOAT: + case federated_language::DataType::DT_FLOAT: return xla::F32; - case v0::DataType::DT_DOUBLE: + case federated_language::DataType::DT_DOUBLE: return xla::F64; - case v0::DataType::DT_COMPLEX64: + case federated_language::DataType::DT_COMPLEX64: return xla::C64; - case v0::DataType::DT_COMPLEX128: + case federated_language::DataType::DT_COMPLEX128: return xla::C128; - case v0::DataType::DT_BFLOAT16: + case federated_language::DataType::DT_BFLOAT16: return xla::BF16; default: return absl::UnimplementedError( @@ -74,7 +74,7 @@ absl::StatusOr PrimitiveTypeFromDataType( } absl::StatusOr ShapeFromTensorType( - const v0::TensorType& tensor_type_pb) { + const federated_language::TensorType& tensor_type_pb) { if (tensor_type_pb.unknown_rank()) { return absl::InvalidArgumentError( "Shapes of unknown rank are not supported in the XLA executor."); @@ -84,8 +84,9 @@ absl::StatusOr ShapeFromTensorType( tensor_type_pb.dims()); } -absl::StatusOr ShapeFromArrayShape(v0::DataType data_type, - const v0::ArrayShape& shape_pb) { +absl::StatusOr ShapeFromArrayShape( + federated_language::DataType data_type, + const federated_language::ArrayShape& shape_pb) { if (shape_pb.unknown_rank()) { return absl::InvalidArgumentError( "Shapes of unknown rank are not supported in the XLA executor."); @@ -140,109 +141,110 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField }); } -absl::StatusOr LiteralFromArray(const v0::Array& array_pb) { +absl::StatusOr LiteralFromArray( + const federated_language::Array& array_pb) { switch (array_pb.kind_case()) { - case v0::Array::kBoolList: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_BOOL, array_pb.shape()))); + case federated_language::Array::kBoolList: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_BOOL, array_pb.shape()))); CopyFromRepeatedField(array_pb.bool_list().value(), literal.data().begin()); return literal; } - case v0::Array::kInt8List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_INT8, array_pb.shape()))); + case federated_language::Array::kInt8List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_INT8, array_pb.shape()))); CopyFromRepeatedField(array_pb.int8_list().value(), literal.data().begin()); return literal; } - case v0::Array::kInt16List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_INT16, array_pb.shape()))); + case federated_language::Array::kInt16List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_INT16, array_pb.shape()))); CopyFromRepeatedField(array_pb.int16_list().value(), literal.data().begin()); return literal; } - case v0::Array::kInt32List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_INT32, array_pb.shape()))); + case federated_language::Array::kInt32List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_INT32, array_pb.shape()))); CopyFromRepeatedField(array_pb.int32_list().value(), literal.data().begin()); return literal; } - case v0::Array::kInt64List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_INT64, array_pb.shape()))); + case federated_language::Array::kInt64List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_INT64, array_pb.shape()))); CopyFromRepeatedField(array_pb.int64_list().value(), literal.data().begin()); return literal; } - case v0::Array::kUint8List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_UINT8, array_pb.shape()))); + case federated_language::Array::kUint8List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_UINT8, array_pb.shape()))); CopyFromRepeatedField(array_pb.uint8_list().value(), literal.data().begin()); return literal; } - case v0::Array::kUint16List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_UINT16, array_pb.shape()))); + case federated_language::Array::kUint16List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_UINT16, array_pb.shape()))); CopyFromRepeatedField(array_pb.uint16_list().value(), literal.data().begin()); return literal; } - case v0::Array::kUint32List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_UINT32, array_pb.shape()))); + case federated_language::Array::kUint32List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_UINT32, array_pb.shape()))); CopyFromRepeatedField(array_pb.uint32_list().value(), literal.data().begin()); return literal; } - case v0::Array::kUint64List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_UINT64, array_pb.shape()))); + case federated_language::Array::kUint64List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_UINT64, array_pb.shape()))); CopyFromRepeatedField(array_pb.uint64_list().value(), literal.data().begin()); return literal; } - case v0::Array::kFloat16List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_HALF, array_pb.shape()))); + case federated_language::Array::kFloat16List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_HALF, array_pb.shape()))); CopyFromRepeatedField(array_pb.float16_list().value(), literal.data().begin()); return literal; } - case v0::Array::kFloat32List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_FLOAT, array_pb.shape()))); + case federated_language::Array::kFloat32List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_FLOAT, array_pb.shape()))); CopyFromRepeatedField(array_pb.float32_list().value(), literal.data().begin()); return literal; } - case v0::Array::kFloat64List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_DOUBLE, array_pb.shape()))); + case federated_language::Array::kFloat64List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_DOUBLE, array_pb.shape()))); CopyFromRepeatedField(array_pb.float64_list().value(), literal.data().begin()); return literal; } - case v0::Array::kComplex64List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_COMPLEX64, array_pb.shape()))); + case federated_language::Array::kComplex64List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_COMPLEX64, array_pb.shape()))); CopyFromRepeatedField(array_pb.complex64_list().value(), literal.data().begin()); return literal; } - case v0::Array::kComplex128List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_COMPLEX128, array_pb.shape()))); + case federated_language::Array::kComplex128List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_COMPLEX128, array_pb.shape()))); CopyFromRepeatedField(array_pb.complex128_list().value(), literal.data().begin()); return literal; } - case v0::Array::kBfloat16List: { - xla::Literal literal(TFF_TRY( - ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape()))); + case federated_language::Array::kBfloat16List: { + xla::Literal literal(TFF_TRY(ShapeFromArrayShape( + federated_language::DataType::DT_BFLOAT16, array_pb.shape()))); CopyFromRepeatedField(array_pb.bfloat16_list().value(), literal.data().begin()); return literal; diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils.h b/tensorflow_federated/cc/core/impl/executors/xla_utils.h index 7a77b90059..9acac52066 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils.h @@ -17,28 +17,30 @@ limitations under the License #define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_EXECUTORS_XLA_UTILS_H_ #include "absl/status/statusor.h" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" namespace tensorflow_federated { -// Creates a xla::PrimitiveType from a ::tensorflow_federated::v0::DataType. +// Creates a xla::PrimitiveType from a ::federated_language::DataType. absl::StatusOr PrimitiveTypeFromDataType( - v0::DataType data_type); + federated_language::DataType data_type); -// Creates a xla::Shape from a v0::TensorType. +// Creates a xla::Shape from a federated_language::TensorType. absl::StatusOr ShapeFromTensorType( - const v0::TensorType& tensor_type_pb); + const federated_language::TensorType& tensor_type_pb); -// Creates a xla::Shape from a v0::ArrayShape. -absl::StatusOr ShapeFromArrayShape(v0::DataType data_type, - const v0::ArrayShape& shape_pb); +// Creates a xla::Shape from a federated_language::ArrayShape. +absl::StatusOr ShapeFromArrayShape( + federated_language::DataType data_type, + const federated_language::ArrayShape& shape_pb); -// Creates a xla::Literal from a v0::Array. -absl::StatusOr LiteralFromArray(const v0::Array& array_pb); +// Creates a xla::Literal from a federated_language::Array. +absl::StatusOr LiteralFromArray( + const federated_language::Array& array_pb); } // namespace tensorflow_federated diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc b/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc index f9b00096fd..cd6c46318f 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc @@ -22,6 +22,9 @@ limitations under the License #include "absl/status/status.h" #include "absl/status/statusor.h" #include "third_party/eigen3/Eigen/Core" +#include "federated_language/proto/array.pb.h" +#include "federated_language/proto/computation.pb.h" +#include "federated_language/proto/data_type.pb.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -32,17 +35,14 @@ limitations under the License #include "tensorflow_federated/cc/core/impl/executors/array_shape_test_utils.h" #include "tensorflow_federated/cc/core/impl/executors/array_test_utils.h" #include "tensorflow_federated/cc/testing/status_matchers.h" -#include "tensorflow_federated/proto/v0/array.pb.h" -#include "tensorflow_federated/proto/v0/computation.pb.h" -#include "tensorflow_federated/proto/v0/data_type.pb.h" namespace tensorflow_federated { namespace { TEST(ShapeFromTensorTypeTest, TestReturnsShape_fully_defined) { std::initializer_list dims = {2, 3}; - v0::TensorType type_pb; - type_pb.set_dtype(v0::DataType::DT_INT32); + federated_language::TensorType type_pb; + type_pb.set_dtype(federated_language::DataType::DT_INT32); type_pb.mutable_dims()->Assign(dims.begin(), dims.end()); const xla::Shape& expected_shape = xla::ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {2, 3}); @@ -54,8 +54,8 @@ TEST(ShapeFromTensorTypeTest, TestReturnsShape_fully_defined) { TEST(ShapeFromTensorTypeTest, TestReturnsShape_scalar) { std::initializer_list dims = {}; - v0::TensorType type_pb; - type_pb.set_dtype(v0::DataType::DT_INT32); + federated_language::TensorType type_pb; + type_pb.set_dtype(federated_language::DataType::DT_INT32); type_pb.mutable_dims()->Assign(dims.begin(), dims.end()); const xla::Shape& expected_shape = xla::ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {}); @@ -67,8 +67,8 @@ TEST(ShapeFromTensorTypeTest, TestReturnsShape_scalar) { TEST(ShapeFromTensorTypeTest, TestFails_partially_defined) { std::initializer_list dims = {2, -1}; - v0::TensorType type_pb; - type_pb.set_dtype(v0::DataType::DT_INT32); + federated_language::TensorType type_pb; + type_pb.set_dtype(federated_language::DataType::DT_INT32); type_pb.mutable_dims()->Assign(dims.begin(), dims.end()); const absl::StatusOr& result = ShapeFromTensorType(type_pb); @@ -78,8 +78,8 @@ TEST(ShapeFromTensorTypeTest, TestFails_partially_defined) { TEST(ShapeFromTensorTypeTest, TestFails_unknown) { std::initializer_list dims = {}; - v0::TensorType type_pb; - type_pb.set_dtype(v0::DataType::DT_INT32); + federated_language::TensorType type_pb; + type_pb.set_dtype(federated_language::DataType::DT_INT32); type_pb.mutable_dims()->Assign(dims.begin(), dims.end()); type_pb.set_unknown_rank(true); @@ -89,48 +89,53 @@ TEST(ShapeFromTensorTypeTest, TestFails_unknown) { } TEST(ShapeFromArrayShapeTest, TestReturnsShape_fully_defined) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({2, 3}); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({2, 3}); const xla::Shape& expected_shape = xla::ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {2, 3}); - const xla::Shape& actual_shape = - TFF_ASSERT_OK(ShapeFromArrayShape(v0::DataType::DT_INT32, shape_pb)); + const xla::Shape& actual_shape = TFF_ASSERT_OK( + ShapeFromArrayShape(federated_language::DataType::DT_INT32, shape_pb)); EXPECT_TRUE(xla::Shape::Equal()(actual_shape, expected_shape)); } TEST(ShapeFromArrayShapeTest, TestReturnsShape_scalar) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({}); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({}); const xla::Shape& expected_shape = xla::ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {}); - const xla::Shape& actual_shape = - TFF_ASSERT_OK(ShapeFromArrayShape(v0::DataType::DT_INT32, shape_pb)); + const xla::Shape& actual_shape = TFF_ASSERT_OK( + ShapeFromArrayShape(federated_language::DataType::DT_INT32, shape_pb)); EXPECT_TRUE(xla::Shape::Equal()(actual_shape, expected_shape)); } TEST(ShapeFromArrayShapeTest, TestFails_partially_defined) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({2, -1}); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({2, -1}); const absl::StatusOr& result = - ShapeFromArrayShape(v0::DataType::DT_INT32, shape_pb); + ShapeFromArrayShape(federated_language::DataType::DT_INT32, shape_pb); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } TEST(ShapeFromArrayShapeTest, TestFails_unknown) { - const v0::ArrayShape& shape_pb = testing::CreateArrayShape({}, true); + const federated_language::ArrayShape& shape_pb = + testing::CreateArrayShape({}, true); const absl::StatusOr& result = - ShapeFromArrayShape(v0::DataType::DT_INT32, shape_pb); + ShapeFromArrayShape(federated_language::DataType::DT_INT32, shape_pb); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } TEST(LiteralFromArrayTest, TestReturnsLiteral_bool) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_BOOL, testing::CreateArrayShape({}), {true})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_BOOL, + testing::CreateArrayShape({}), {true})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -140,8 +145,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_bool) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_int8) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_INT8, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_INT8, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -151,8 +157,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_int8) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_int16) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_INT16, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_INT16, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -162,8 +169,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_int16) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_int32) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_INT32, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_INT32, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -173,8 +181,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_int32) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_int64) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_INT64, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_INT64, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -184,8 +193,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_int64) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_uint8) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_UINT8, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_UINT8, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -195,8 +205,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_uint8) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_uint16) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_UINT16, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_UINT16, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -206,8 +217,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_uint16) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_uint32) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_UINT32, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_UINT32, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -217,8 +229,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_uint32) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_uint64) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_UINT64, testing::CreateArrayShape({}), {1})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_UINT64, + testing::CreateArrayShape({}), {1})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -228,9 +241,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_uint64) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_float16) { - const v0::Array& array_pb = TFF_ASSERT_OK( - testing::CreateArray(v0::DataType::DT_HALF, testing::CreateArrayShape({}), - {Eigen::half{1.0}})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_HALF, + testing::CreateArrayShape({}), {Eigen::half{1.0}})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -240,8 +253,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float16) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_float32) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_FLOAT, + testing::CreateArrayShape({}), {1.0})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -251,8 +265,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float32) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_float64) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_DOUBLE, testing::CreateArrayShape({}), {1.0})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_DOUBLE, + testing::CreateArrayShape({}), {1.0})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -262,9 +277,10 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float64) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_complex64) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_COMPLEX64, testing::CreateArrayShape({}), - {std::complex{1.0, 1.0}})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray( + federated_language::DataType::DT_COMPLEX64, + testing::CreateArrayShape({}), {std::complex{1.0, 1.0}})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -275,9 +291,10 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_complex64) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_complex128) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_COMPLEX128, testing::CreateArrayShape({}), - {std::complex{1.0, 1.0}})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray( + federated_language::DataType::DT_COMPLEX128, + testing::CreateArrayShape({}), {std::complex{1.0, 1.0}})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -288,9 +305,10 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_complex128) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}), - {Eigen::bfloat16{1.0}})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray( + federated_language::DataType::DT_BFLOAT16, + testing::CreateArrayShape({}), {Eigen::bfloat16{1.0}})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -301,9 +319,10 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) { } TEST(LiteralFromArrayTest, TestReturnsLiteral_array) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_INT32, testing::CreateArrayShape({2, 3}), - {1, 2, 3, 4, 5, 6})); + const federated_language::Array& array_pb = + TFF_ASSERT_OK(testing::CreateArray(federated_language::DataType::DT_INT32, + testing::CreateArrayShape({2, 3}), + {1, 2, 3, 4, 5, 6})); const xla::Literal& actual_literal = TFF_ASSERT_OK(LiteralFromArray(array_pb)); @@ -314,8 +333,9 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_array) { } TEST(LiteralFromArrayTest, TestFails_string) { - const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( - v0::DataType::DT_STRING, testing::CreateArrayShape({}), {"a"})); + const federated_language::Array& array_pb = TFF_ASSERT_OK( + testing::CreateArray(federated_language::DataType::DT_STRING, + testing::CreateArrayShape({}), {"a"})); const absl::StatusOr& result = LiteralFromArray(array_pb); diff --git a/tensorflow_federated/proto/v0/BUILD b/tensorflow_federated/proto/v0/BUILD index e59dba5c78..18c87d39cd 100644 --- a/tensorflow_federated/proto/v0/BUILD +++ b/tensorflow_federated/proto/v0/BUILD @@ -16,63 +16,12 @@ py_library( visibility = ["//tools/python_package:python_package_tool"], ) -proto_library( - name = "array_proto", - srcs = ["array.proto"], - deps = [":data_type_proto"], -) - -py_proto_library( - name = "array_py_pb2", - deps = [":array_proto"], -) - -cc_proto_library( - name = "array_cc_proto", - deps = [":array_proto"], -) - -proto_library( - name = "computation_proto", - srcs = ["computation.proto"], - deps = [ - ":array_proto", - ":data_type_proto", - "@com_google_protobuf//:any_proto", - ], -) - -py_proto_library( - name = "computation_py_pb2", - deps = [":computation_proto"], -) - -cc_proto_library( - name = "computation_cc_proto", - deps = [":computation_proto"], -) - -proto_library( - name = "data_type_proto", - srcs = ["data_type.proto"], -) - -py_proto_library( - name = "data_type_py_pb2", - deps = [":data_type_proto"], -) - -cc_proto_library( - name = "data_type_cc_proto", - deps = [":data_type_proto"], -) - proto_library( name = "executor_proto", srcs = ["executor.proto"], deps = [ - ":array_proto", - ":computation_proto", + "@federated_language//federated_language/proto:array_proto", + "@federated_language//federated_language/proto:computation_proto", ], ) diff --git a/tensorflow_federated/proto/v0/array.proto b/tensorflow_federated/proto/v0/array.proto deleted file mode 100644 index aad5d34f7b..0000000000 --- a/tensorflow_federated/proto/v0/array.proto +++ /dev/null @@ -1,86 +0,0 @@ -syntax = "proto3"; - -package tensorflow_federated.v0; - -import "tensorflow_federated/proto/v0/data_type.proto"; - -// ArrayShape is the shape of an `Array`, and may be one of the following: -// -// * Fully-defined: Has a known number of dimensions and a known size for each -// dimension (e.g. dim=(2, 3)). -// * Partially-defined: Has a known number of dimensions, and an unknown size -// (repesented as size -1) for one or more dimension (e.g. dim=(2, -1)). -// * Unknown: Has an unknown number of dimensions (unknown_rank=True). -// * Scalar: Has no dimensions (dim=(), unknown_rank=False). -message ArrayShape { - repeated int64 dim = 1; - bool unknown_rank = 2; -} - -// Array is the native representation of an array. -// -// This protobuf resembles the equivalent TensorFlow -// (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto) -// and XLA (https://github.com/openxla/xla/blob/main/xla/xla_data.proto) -// constructs and adopts some design decisions made by those platforms -// in order to reduce the complexity and cost of converting values between those -// environments, including: -// -// * How dtypes are packed. -// * Which dtypes are supported. -// * How strings are represented. -message Array { - DataType dtype = 1; - ArrayShape shape = 2; - - // Serialized raw array content. This representation can be used for all array - // types. The purpose of this representation is to reduce serialization - // overhead during RPC call by avoiding serialization of many repeated small - // items. - optional bytes content = 18; - - message BoolList { - repeated bool value = 1; - } - // INT8, INT16, INT32, UINT8, UINT16, HALF, BFLOAT16 - message IntList { - repeated int32 value = 1; - } - message Int64List { - repeated int64 value = 1; - } - message Uint32List { - repeated uint32 value = 1; - } - message Uint64List { - repeated uint64 value = 1; - } - message FloatList { - repeated float value = 1; - } - message DoubleList { - repeated double value = 1; - } - message BytesList { - repeated bytes value = 1; - } - - oneof kind { - BoolList bool_list = 3; - IntList int8_list = 4; - IntList int16_list = 5; - IntList int32_list = 6; - Int64List int64_list = 7; - IntList uint8_list = 8; - IntList uint16_list = 9; - Uint32List uint32_list = 10; - Uint64List uint64_list = 11; - IntList float16_list = 12; - FloatList float32_list = 13; - DoubleList float64_list = 14; - FloatList complex64_list = 15; - DoubleList complex128_list = 16; - IntList bfloat16_list = 19; - BytesList string_list = 17; - } -} diff --git a/tensorflow_federated/proto/v0/computation.proto b/tensorflow_federated/proto/v0/computation.proto deleted file mode 100644 index 080d3d9121..0000000000 --- a/tensorflow_federated/proto/v0/computation.proto +++ /dev/null @@ -1,937 +0,0 @@ -syntax = "proto3"; - -package tensorflow_federated.v0; - -import "google/protobuf/any.proto"; -import "tensorflow_federated/proto/v0/array.proto"; -import "tensorflow_federated/proto/v0/data_type.proto"; - -// A core data structure that contains a serialized representation of a unit of -// processing to perform by the TensorFlow Federated framework. This data -// structure is the primary unit of composition and the means by which we -// represent, store, and exchange federated computations and their constituents -// between system components. It is the lowest and smallest programmable -// abstraction layer that a range of higher-level APIs will be layered upon, -// structured around the minimum set of concepts and abstractions that provide -// a level of expressiveness sufficient to efficiently support current and -// anticipated uses. This layer is not intended for consumption by most users. -// -// In its most general sense, an instance of a Computation as defined here is -// simply an expression that produces a certain value. The structure of this -// expression, typically nested, determines how this value is intended to be -// computed (hence the term "computation"). We may use terms "expression" and -// "computation" interchangeably in this and other files, although technically, -// the term "computation" refers to a process, whereas "expression" refers to -// a specification of that process. -message Computation { - // The type of what's represented by this structure, which may be functional - // or non-functional. If it is a TensorFlow block or a lambda expression, - // the type will be functional. If it is a Struct, or a Call that returns a - // tensor or a Struct in the result, the type will be non-functional. - // - // A Call is a typical way to represent an invocation of a top-level federated - // computation with all its parameters fully specified. Thus, a top-level - // computation with all of its parameters filled in may have a non-functional - // type (the same as type of the result it computes). The illustrative - // example to think of is "(x -> x + 10)(20)", the type of which is an int, a - // non-functional type. If a top-level federated computation has all of its - // parameters filled in, it will assume a similar form. - Type type = 1; - - // The specification of the computation to perform. - // - // A hypothetical example of a federated computation definition in Python, - // expressed in a yet-to-be-defined syntax, might translate into definitions - // in a serialized form as shown below. - // - // @tff.computation - // def fed_eval(model): - // - // @tfe.defun - // def local_eval(model): - // ... - // return {'loss': ..., 'accuracy': ...} - // - // client_model = tff.federated_broadcast(model) - // client_metrics = tff.federated_map(local_eval, client_model) - // return tff.federated_mean(client_metrics) - // - // - // fed_eval = Computation(lambda=Lambda( - // parameter_name='model', - // result=Computation(block=Block( - // local=[ - // Block.Local(name='local_eval', value=Computation( - // tensorflow=TensorFlow(...))), - // Block.Local(name='client_model', value=Computation( - // call=Call( - // function=Computation( - // intrinsic=Intrinsic(uri='federated_broadcast')), - // argument=Computation( - // reference=Reference(name='model'))))), - // Block.Local(name='client_metrics', value=Computation( - // call=Call( - // function=Computation( - // intrinsic=Intrinsic(uri='federated_map')), - // argument=Computation( - // struct=Struct(element=[ - // Struct.Element( - // value=Computation( - // reference=Reference( - // name='local_eval'))), - // Struct.Element( - // value=Computation( - // reference=Reference( - // name='local_client_model'))) - // ])))))], - // result=Computation( - // call=Call( - // function=Computation( - // intrinsic=Intrinsic(uri='federated_mean')), - // argument=Computation( - // reference=Reference(name='client_metrics')))))))) - // - oneof computation { - // NON-COMPOSITIONAL CONSTRUCTS. - // - // The following constructs are the basic building blocks that can be - // composed into larger computations with the use of the compositional - // constructs defined below. - - // TensorFlow computation. TensorFlow computations have functional type - // signatures that cannot contain FederatedTypes, as they execute locally. - // In order to construct a TensorFlow computation that maps a federated - // value pointwise, one must use a federated map intrinsic (to be defined). - TensorFlow tensorflow = 2; - - // A built-in federated communication operator such as broadcast, federated - // sum, etc., or one of the custom operators added to the framework, and - // recognized by the compiler pipeline. Intrinsics have functional types, - // and most are defined as templates that can operate on abstract types, - // and/or federated values with arbitrary placements. - Intrinsic intrinsic = 3; - - // An external source of data to be used by a computation. - Data data = 10; - - // COMPOSITIONAL CONSTRUCTS. - // - // The following constructs can be used to combine simpler computations - // into more complex ones. For example, they can be used to express the - // top-level orchestration logic of a federated computation that combines - // blocks of client-side and server-side TensorFlow code with federated - // communication operators such as federated aggregation or broadcast. - - // A lambda expression is the primary means of defining new parameterized - // computations. lambdas always have functional types. - Lambda lambda = 4; - - // A block of computation logic, i.e., a series of expressions that refer - // to one-another. This mechanism is intended as a primary means of - // breaking down longer sequences of processing into simpler parts. A block - // can have a functional or a non-functional type (matching the type of its - // result), as it is primarily a mechanism for organizing code. - Block block = 5; - - // A reference to a name defined in a surrounding context, such as a Lambda - // or a Block, with the usual scoping rules (the name refers to the - // innermost scope in which it is defined). Always matches the type of the - // the parameter it references, i.e., T if the type of the lambda is T->T'. - // For example, in a Lambda "x : int -> foo(x)", which associates locally - // name "x" with its parameter, the reference to "x" will be of type "int", - // just as the parameter of the lambda in which the name "x" is defined. - Reference reference = 6; - - // A function call is the primary means of using lambdas, TensorFlow blocks, - // and other types of functional constructs to compute a specific result in - // a concrete context. The type of the call is the same as the type of the - // result of the function being called, i.e., a call with parameter of type - // T to a function of type T -> T' has type T'. - Call call = 7; - - // A struct is explicitly constructed from individual member values. - Struct struct = 8; - - // A selection by name or index from the result of another expression that - // returns a Struct. The type of the selection matches the type of the - // Struct element being selected (known statically, as the name or index is - // known statically, rather than computed). - // - // Note: In higher layers of the API, we will offer convenience mechanisms - // such as selection from a federated type. For example, if "x" is of a - // federated type "{}@clients", we will allow notation - // such as "x.foo" as a convenient shortcut for a pointwise selection that - // might be written as "federated_map(x, y->y.foo)" in a more complete form - // even though "x" is technically not a Struct. Here at the level of - // the Computation proto, however, we will represent computations in their - // fully fleshed-out form, with map and other implicit operators already - // injected at construction time by the framework as it translates a Python - // source cosde that defines a computation into this serialized form. - Selection selection = 9; - - // A placement literal. - Placement placement = 11; - - // A value literal. - Literal literal = 14; - - // EXPERIMENTAL CONSTRUCTS. - // - // The following constructs are currently considered experimental, and are - // not formally supported yet. They may be recognized, partially or fully - // handled by parts of the TFF codebase, but at this stage, one should not - // depend on this being the case. Their exact representation may continue - // to evolve, or they may be removed altogether. When these constructs are - // ready for public consumption and come with an explicit support, we will - // remove the "experimental" designation. - - // A local (non-federated) computation expressed in XLA. XLA computations - // have functional type signatures, and are used in a manner similar to - // local computations expressed in TensorFlow. However, not all types of - // TensorFlow computations are expressible in XLA (see below for the list - // of current limitations). - Xla xla = 12; - } - - // Reserving: The field is deleted. - reserved 13; - - // NEXT ID: 14 -} - -// A generic representation of an arbitrary type, defined as a variant over a -// number of primitive and compound types that can be nested. Note that not all -// nestings expressible with this structure may be valid, e.g., it may not make -// sense to declare a sequence of functions, or a federated type in which -// individual member values are themselves federated. However, rather than -// constraining the set of possible nestings at the syntactic level, which would -// increase boilerplate and could prove limiting in the future, we keep this -// variant structure simple, and we let the set of all valid type nestings be -// determined by the set of the currently supported operators. The current -// limitations on nesting are as follows: -// - FederatedType and FunctionType cannot be nested within a FederatedType or -// within a SequenceType. Currently, these may only be nested within a -// StructType. -// - A SequenceType currently cannot be nested within another SequenceType. -message Type { - oneof type { - FunctionType function = 1; - StructType struct = 2; - SequenceType sequence = 3; - TensorType tensor = 4; - AbstractType abstract = 5; - PlacementType placement = 6; - FederatedType federated = 7; - } -} - -// A representation of a functional type. Functions must have at most a single -// parameter and a single result. Multiple parameters or results to be modeled -// as compound types (e.g., as Structs). Note that since functions accept -// generic types, one can declare functions as parameters or results of other -// functions. We may not support functions as first-class values directly in -// the API surface, but the ability to express this is useful in defining type -// signatures for federated communication operators, and to support various -// types of extensibility. -// Concise syntax for examples of functional types: "T -> T'", where T, T' are -// the types of parameter and result, respectively. -message FunctionType { - Type parameter = 1; - Type result = 2; -} - -// A representation of a type of a struct. A struct is a compound type -// based on a similar type in Python that defines a finite set of named members, -// the types of which are known statically, that are arranged in a prescribed -// order and can be referred to by their position within the Struct. Note that -// besides structs, this abstract type can also be used to represent dicts, -// OrderedDicts, and regular tuples in Python. -// Concise syntax for examples of struct types: "T_i" or "name_i=T_i" separated -// by commas and optionally enclosed in "<>" (e.g., ""), -// where name_i is the optional name, and T_i is the type of i-th element. -message StructType { - repeated Element element = 1; - message Element { - string name = 1; - Type value = 2; - } -} - -// A representation of a type of a sequence. A sequence is a data structure -// that contains multiple elements of the same type that can be accessed only -// in a sequential manner, i.e., through an iterator. For now, we assume that -// a sequence can only be consumed once, i.e., there's no concept of iterator -// reset, as this facilitates high-performance implementations. We may add a -// notion of resettability in the future by introducing additional fields here -// while keeping non-resettability of sequences as the default. -// Concise syntax for examples of sequence types: "T*", where T is the type of -// elements. -message SequenceType { - Type element = 1; -} - -// A representation of a type of a single tensor in TensorFlow. Aspects such -// as sparseness are not intended to be represented at this level. -// Concise syntax for examples of tensor types: "dtype[shape]" or "dtype" for -// scalars, e.g., "bool[10]". -message TensorType { - // The data type of the tensor. - DataType dtype = 1; - - // The sizes of each dimension of the tensor. - // - // Undefined dimensions are allowed and represented by -1. Defined and - // undefined dimensions are to be considered distinct for type checking - // purposes. - repeated int64 dims = 2; - - // True iff the number of dimensions is unknown. - // - // If `dims` is unset: - // - `unknown_rank` == True corresponds to None - // - `unknown_rank` == False corresponds to [] - bool unknown_rank = 3; -} - -// A representation of an abstract type identified by a string label (analogous -// to "typename T" in C++, with "T" being the label). All occurrences of an -// abstract type with the same label within a type signature are interpreted as -// referring to the same concrete type. Abstract types can thus be used to -// represent templates similar to templates in C++. The label does not have any -// specific meaning otherwise. Any bijective renaming of all labels within a -// type signature is semantically a no-op (i.e., the resulting type definition -// is semantically identical to the original before renaming). The label may be -// modified by the compiler (e.g., due to naming conflicts). -// An AbstractType T might be used, for example, to define a signature of a -// generic aggregation operator as "federated_sum: {T}@clients -> T@server". -// Concise syntax for examples of abstract types: variations of uppercase "T", -// e.g., as in "T -> T'". -message AbstractType { - // The label used to refer to this abstract type within a type signature. - string label = 1; -} - -// The term `placement` refers to a representation of an instance of a built-in -// opaque type that conceptually represents a (membership of some) collective -// of participants in a distributed system that may participate in some part of -// a federated computation. -// -// In a typical federated computation, there would typically be at least one -// group of client devices, one or more groups of intermediate aggregators in a -// multi-tiered server architecture, and a central coordinator (perhaps a -// singleton group). With each of these groups, one would associate a separate -// placement (a separate instance of the built-in "placement" type). -// -// Placements are intended to be passed as arguments to some of the federated -// communication operators to determine the group of participants involved in -// the underlying federated communication protocol. -// -// In addition, placements can be used to define federated types (see below), -// i.e., types, values of which are hosted by members of a given collective, and -// thus potentially distributed across multiple locations. In a fully-specified -// federated computation, each concrete value (e.g., tensor) would typically -// have an associated concrete placement value to indicate which group of system -// participants (clients, aggregator or coordinator instances, etc.) it is -// hosted on. -// -// While placement is a first-class type, instances of which may be passed as -// parameters or returned as results, it is not equivalent to a simple vector -// of device addresses. A computation cannot list, add, remove, or test for -// existence of a particular device in a placement, as membership could be -// determined or influenced by factors outside of the programmer's control. For -// example, the membership of the collective of client devices represented by -// a "client" placement will depend on which devices choose to join the system -// and further influenced by factors such as failures and network delays. In -// most types of environments, the membership of a given group of participants -// could be dynamically evolving over time. Federated computations are defined -// at a higher level of abstraction that does not involve dealing with the -// identities of the individual devices. - -// A specification of a placement in a federated type. There are two ways of -// specifying a placement in this context that correspond to the two fields in -// the oneof below. Placement labels are used to construct template types of -// federated communication operators that can be applied to federated values. -// They relate all the identically-labeled placements that appear in the type -// signature without prescribing what specifically those placements must be. -// For example, consider the type signature below: -// -// federated_broadcast: T, p: placement -> T@p -// -// Here, "p" is a placement label, the role of which is simply to link the left -// and right sides of the type signature. The represenation of this type -// signature will use PlacementLabel on the left side. -// -// Concrete placement values are essentially placement literals, same as those -// that might appear in a computation body. They are used to bind types to -// specific placements with definite global meaning in a -// particular type of runtime environment. -message PlacementSpec { - oneof placement { - PlacementLabel label = 1; - Placement value = 2; - } -} - -// A representation of an abstract placement identified by a string label. -// All occurrences of this abstract placement label within a type signature are -// interpreted as referring to the same specific placement, similarly to how -// this is done for abstract type labels (except that equality of placement -// labels indicates equality of values, not just types). The abstract placement -// label does not have any specific meaning otherwise, and it is not intended to -// be compared with anything other than another abstract placement label -// contained within the same type signature. A bijective renaming of all -// abstract placement labels contained in a type signature is a semantic no-op. -// The label may be modified by the compiler (e.g., due to name conflicts). -message PlacementLabel { - // The label used to refer to this specific placement within a type signature. - string label = 1; -} - -// A representation of a specific placement defined globally by the runtime -// environment, and embedded as a literal of the "placement" type within a type -// signature or a computation definition. Unlike the abstract placement labels, -// the URIs in these placement values have a definite global meaning for all -// computations executed within the same environment. The exact set of global -// placement URIs and their meaning will depend on the system architecture and -// the capabilities of the platform. For example, in a production setting, these -// might include dedicated URIs to represent clients, intermediate aggregators, -// and coordinator placements. -message Placement { - // The globally unique URI that defines a specific global placement instance. - // For example, an URI might represent the global collective of all mobile - // devices running a certain app, or it might represent the specific - // well-known address of a central coordinator. The exact naming schemes and - // interpretation of these URIs is TBD, and will be documented later. - string uri = 1; -} - -// A representation of a federated type, i.e., one in which member components of -// the federated value are hosted on a collective of devices in a distributed -// system (where in some cases, that collective may be a singleton). As noted -// above in the comment on "PlacementType", examples of such collectives could -// include client devices, intermediate aggregators, central coordinator, etc., -// with one or more participants. Note that a federated type is a dependent -// type, as the placement label or value contained herein binds it to a specific -// placement, either one that's defined globally, or one that's supplied as a -// parameter and defined in another part of a computation's type signature. -// Concise syntax for federated types: "T@p" or "{T}@p" when "all_equal" is True -// or False, respectively, where "T" is the type of members, and "p" is either -// a placement label or a placement value (generally clear from context). -message FederatedType { - // A specification of placement that identifies the collective of participants - // in a distributed system on which member components of this federated value - // are hosted. - // - // If the federated type appears as a part of a functional type signature, - // this placement will generally be defined using a PlacementLabel to bind it - // to the type of the parameter, e.g., as below: - // - // federated_broadcast: T, p: placement -> T@p - // - // In the above "T@p" is a federated type, with label "p" (represented in the - // type as a PlacementLabel) simply serving as a reference to the parameter - // on the left. - // - // On the other hand, if a federated type appears on its own, not tied to the - // placement of any function parameter, the placement specified here will be - // a concrete placement literal (represented by a PlacementValue). - PlacementSpec placement = 1; - - // A bit that, if set, indicates that the member components of the federated - // value are all equal (if not set, member components may vary). This - // distinction is only meaningful for placements that represent collectives, - // such as clients or intermediate aggregators. For placements that represent - // centralized components (such as a central coordinator), this property is - // trivially satisfied (and still documented by setting this bit to True). - bool all_equal = 2; - - // The type of the local member components of the federated value, i.e., the - // components that are locally hosted on each individual participant (member - // of the collective determined by the "placement" above). - Type member = 3; -} - -// A representation of the type of placements (see the discussion above by the -// definition of the Placement message that represents instances of this type). -// This message is only used in situations, where placement is passed as a -// first-class value (e.g., in the argument to broadcast). The specfications of -// federated types only refer to specific placements (see Placement above). -// Note that there is only a single primitive "placement" type. The embedded -// field "instance_label" does not qualify the type and does not affect type -// equality. It is only used to annotate the instance of this type as it appears -// in a type signature in order to form dependent types. -// Concise syntax for the placement type: "placement" for the type itself, and -// "p: placement" to annotate the specific entry in the type signature with the -// label "p". -message PlacementType { - // An optional label that can be used to refer to the specific instance of the - // "placement" type represented by this entry in the type signature. If this - // field is present in the PlacementType message, generally as a parameter in - // a functional type signature, the label is associated with the specific - // placement value supplied in that parameter, which allows it to be used to - // specify a federated type hosted by the collective of participants - // represented by this placement. For example, consider this type signature: - // - // federated_broadcast: T, p: placement -> T@p - // - // The type specification of the 2nd element of the broadcast argument Struct - // would be PlacementType(instance_label=PlacementLabel(label='p')). Here, the - // type of the second element is still simply "placement"; as noted above, - // there is only one such built-in type to represent all sorts of collectives. - // The presence of the label only associates 'p' with the value of the second - // element of the parameter Struct. On the right side, the pecification of the - // federated result type contains Placement(label=PlacementLabel(label='p')), - // thus binding the placement of the result to the value in the argument. When - // comparing types, the presence of this label is ignored. - PlacementLabel instance_label = 1; -} - -// A representation of a section of TensorFlow code. -// -// The type signature associated with this type of computation must be defined -// only in terms of tensors, structs, and sequences. Sequences cannot be nested. -// -// At the moment, we only allow sequences as a parameters (note that pointwise -// transformations of sequences can still be expressed using a map intrinsic). -// This restriction may be relaxed in the future when support for handling data -// sets as first-class objects in TensorFlow evolves. -// -// Note that unlike in polymorphic functions created by tf.defuns, the chosen -// representation requires all type signatures, including those of individual -// elements of a sequence, to be fully specified. In case of sequences, the -// structure of their elements is effectively encoded in the parts of the graph -// that constitute the serialized representation of tf.data.Datasets and -// iterators. -// -// While we will offer support for writing polymorphic TensorFlow logic, types -// will be captured automatically and made concrete based on usage at the Python -// level of the API. Users of TFF will not need to declare them explicitly, but -// template specialization will happen before computation logic gets serialized. -// -// Next id: 8 -message TensorFlow { - // The semantics is as follows: the graph embedded here will be instantiated, - // with all placeholder components of the parameter bound to concrete tensors - // or, in case of sequences, to iterators associated with concrete datasets. - // The compomnents of the result will then all be simultaneously evaluated in - // what corresponds to a single Session.run() in non-eager mode. - - // Note: Currently, there is no way to represent any higher-level scripting - // over the graph. We require that all control flow logic be expressed using - // control dependencies and other TensorFlow constructs and triggered by the - // evaluation of outputs within a single Session.run(), as postulated above. - // Depending on how restrictive this turns out to be we might, or might not, - // add a script that describes a sequence of Session.run() calls, one-off or - // repeated in a loop, as an optional component in a TensorFlow computation, - // to address the impedance mismatch between push- and pull-based styles of - // processing supported by various parts of the target execution environment. - - // A serialized representation of a TensorFlow graph to execute. - // - // Stores a tensorflow.GraphDef message. - // Note: This representation may evolve, e.g., get replaced with a MetaGraph, - // SavedModel, or a similar structure. Dependencies on the exact form of the - // graph encoding used here should be kept to minimum, and proxied by wrapper - // libraries for composing computations in python/core/impl/. - // - // TODO: b/117428091 - Update this representation based on the emerging TF 2.0 - // serialization standards as needed if/when they meet the constraints of the - // target production environments, and provided that they don't introduce - // additional complexity. - google.protobuf.Any graph_def = 1; - - // String name of an initialization op to run on the graph before fetching - // results. This op is intended only to be used for running tf.Variable - // initializers. - string initialize_op = 4; - - // String name of a tensor which may be fed a unique identifier token for the - // current session. This allows TensorFlow custom ops to refer to - // session-global values created by the runner of the current session. - string session_token_tensor_name = 6; - - // A pair of bindings for the parameter and the result. The parameter binding - // can be omitted if the computation does not declare a parameter. The result - // binding is mandatory, as all TensorFlow computations must declare results. - Binding parameter = 2; - Binding result = 3; - - // A general representation of a binding of either a parameter or a result to - // a part of the embedded TensorFlow graph. Note that the structure of the - // binding is nested, and parallels the structure of the corresponding part of - // the type signature. - message Binding { - oneof binding { - // A binding associated with a struct in the type signature. Specifies an - // individual binding for each element of the struct. - StructBinding struct = 1; - - // A binding associated with a (logical) tensor in the type signature. - // Associates that tensor to one or more (concrete) tensors in the graph. - TensorBinding tensor = 2; - - // A binding associated with a sequence. Associates the sequence with a - // part of the TensorFlow graph that will represent a data set iterator, - // next element, or an equivalent iterator-like structure. - SequenceBinding sequence = 3; - } - } - - // A binding of a Struct declared in the type signature to parts of the - // embedded TensorFlow graph. - message StructBinding { - // Bindings for elements of the Struct. The number of elements in this field - // must be equal to the number of Struct elements declared in the type - // signature, with the k-th binding declared here corresponding to the k-th - // Struct element in the type signature. The element names are omitted since - // they are redundant (correspondence is established by element order). - repeated Binding element = 1; - } - - // A representation of a single tensor declared in the type signature in the - // serialized graph representation embedded here. - message TensorBinding { - oneof binding { - // The name of a dense tensor in a TensorFlow graph that corresponds to a - // single tensor component in the type signature. - string tensor_name = 1; - - // Note: This structure may eventually be extended with non-dense tensor - // encodings, such as .tensorflow.TensorInfo.CooSparse. - } - } - - // A representation of a sequence declared in the type signature. - message SequenceBinding { - // Previously was `iterator_string_handle_name`, but now only - // `variant_tensor_name` is supported. - reserved 1; - - oneof binding { - // The name of the variant tensor that represents the data set created - // using `tf.data.experimental.from_variant`. - string variant_tensor_name = 2; - - // The name of the string tensor that represents the data set created - // using `tf.raw_ops.DatasetFromGraph`. - string graph_def_tensor_name = 3; - - // Note: This structure will likely evolve and get extended with other - // means of encoding data sets in the serialized graph representation. - } - } - - // An optional id that can be used to identify identical TensorFlow messages - // without having to compare the (potentially large) `graph_def` fields. - // - // This field is not intended to be set during comptuation - // construction/tracing. Rather, it is designed as a final compilation pass - // that allows execution stacks to "cache" the graphs across invoke calls, - // avoiding costly graph parsing every invocation. - // - // The id is NOT required to be unique across machines, meaning two machines - // producing the same graph_def may have the same ids. If these machines - // should not be talking to the same execution stack. - // - // NOTE: the default value of 0 has the same meaning as having the field - // unset, and having no id. Any code setting this value should exclude zero. - message CacheKey { - uint64 id = 1; - } - CacheKey cache_key = 5; - - reserved 7; // LayoutMap layout_map -} - -// A representation of an intrinsic function. Intrinsics are functions that are -// known to the framework, and uniquely identified by a URI. This includes both -// the standard federated communication operators, such as, e.g., broadcast, -// federated sum, secure aggregation, and custom operators that might be added -// by the user to the pipeline. The compiler recognizes the intrinsics, and -// replaces them with a suitable implementation. Intrinsics may be both generic -// and specialized, low- and high-level. The exact naming scheme used to -// identify them, and how it can be extended to support new operators defined by -// external contributors, will be described elsewhere. -message Intrinsic { - // The URI that uniquely identifies the intrinsic within the set of operators - // built into the framework. - string uri = 1; -} - -// A representation of a parameterized computation defined as a lambda -// expression that consists of a single parameter name, and an expression that -// contains references to this parameter name (the "name" computation variant). -// Lambdas can be nested, e.g., the result can also be a lambda or contain a -// lambda. Inner lambdas are allowed to refer to the parameter defined in the -// outer lambdas. We assume the usual rules of name hiding: inner names obscure -// the outer names. -// -// Concise syntax for lambdas: "parameter_name -> comp" where "comp" represents -// a parameterized computation that produces the result, or in the more general -// form "parameter_name : T -> comp" to indicate that parameter is of type "T". -// For example, a lambda that takes a 2-Struct of an unary operator and an -// integer as input, and returns the result of calling the unary operator -// on the integer, can be written as "x -> x[0](x[1])", or in the full form with -// type annotation as "x: <(int->int), arg=int> -> x[0](x[1])". -message Lambda { - // The name to use internally within this lambda to refer to the parameter. - // The parameter is mandatory. The name defined here can be used internally - // anywhere in the result computation, except if overridden in a nested - // lambda, where it can be hidden by a parameter with a conflicting name. - string parameter_name = 1; - - // A computation that represents the result of applying the lambda to the - // parameter. The result may (almost always will) contain references to the - // parameter defined above. - Computation result = 2; - - // Note that a Lambda as a whole must have a functional type T -> T', where - // T' is the type of the result, and T is the type of all references to the - // parameter within the result. -} - -// A representation of a body of computation logic broken down into a sequence -// of local definitions that gradually build up towards a single final result -// expression. A block defines a sequence of local names, each associated with -// a computation. Computations associated with names introduced later can -// refer to names introduced earlier. At the end of a block is a single result -// computation defined in terms of those locals. It is similar to LET* in LISP. -// -// The intended usage of this abstraction is to break down complex processing -// into simpler, smaller, easier to understand units that are easier to work -// with in this broken-down representation, as opposed to a single monolithic -// complex expression. We expect it to be used, e.g., to represent top-level -// federated orchestration logic. -// -// A block is technically a redundant abstraction, as it can be equivalently -// represented using lambda expressions. For example, a simple block of the -// form "let x=y in z" is equivalent to "(x->z)(y)". Larger blocks can likewise -// be represented similarly as nested lambdas. The main purpose of introducing -// this abstraction is simplicity. While expressible via lambdas, a sequential -// representation is preferred over nested lambdas as it is more readable and -// easier to debug, and more closely matches how code is expected to be executed -// by a runtime environment, in which higher-order functions may be unsupported. -// -// One way to think of blocks is as a generalization of a GraphDef, and such, -// a mechanism for constructing data flow graphs that can include TensorFlow -// blocks and various federated communication operators as processing nodes. -// Indeed, this is the primary intended usage of blocks. In this interpretation -// a block can be thought of as a direct acyclic graph, with the locals and -// the result being the graph "nodes". Locals represent various partial results -// computed along the way, and the result is the "op" that represents the -// output. Each node has associated with it an expression (computation) that -// specifies how to derive its value from the values represented by other nodes -// referenced by name. The presence of such reference to one node's name inside -// another node's expression (computation) can be interpreted as a dependency -// edge in a data flow graph. Indeed, the data flow interpretation corresponds -// to the manner in which processing is expected to flow. -// -// Concise syntax: "let name_1=comp_1, ...., name_k=comp_k in comp" with -// "name_k" and "comp_k" representing the names of the locals, and computations -// that compute the results that those names represent. For example, a complex -// expression "x[0](x[1])" can be represented in a slightly more expanded -// form as "let f=x[0], v=[1] in f(v)". -message Block { - // One or more locals defined within the block, each associating a name with a - // computation. Computations, whether those associated with the locals, or - // that associated with the result below, can contain references to names - // defined earlier, here or in the surrounding context. Self-references are - // prohibited. All names introduced here must be different. Since execution - // semantics at this level is purely functional without side effects, the - // ordering in which the locals are declared is not significant, as it is only - // the dependencies between the computations that effectively determine the - // causal relationships that constrain the order of execution. - // - // Blocks can be nested, just as lambdas, and the same name scoping rules - // apply, i.e., blocks (or lambdas) contained within an embedded computation, - // whether in a local or in the result, are allowed to refer to names defined - // in an outer lambda or block (unless obscured by a nested declaration). - // If names defined in the outer context conflict with those defined in the - // inner congtext (here), the inner names hide outer names in the context in - // which they are defined. Thus, for example, in "x -> let x=1, y=x+1 in y", - // the "x=1" would hide the lambda parameter, and therefore "y=x+1" would - // refer to the inner "x". - repeated Local local = 1; - message Local { - string name = 1; - Computation value = 2; - } - - // The result computation. Always required. The computation typically refers - // to locals defined above by name, just like the result in a lambda. - Computation result = 2; -} - -// A reference to a computation defined as a local in a block, or to the -// parameter of a lambda. -message Reference { - string name = 1; -} - -// A representation of a function call. -// -// The concise notation for function calls is "f(x)" or "f()", where "f" is the -// function, and "x" is the optional argument. -message Call { - // A computation that represents the function to call. The value that this - // represents must be of a functional type. - Computation function = 1; - - // A computation that represents the argument to the function specified above. - // Present if and only if "function" declares a parameter. Must match the - // function's parameter type (i.e., the function's parameter type must be - // assignable from the argument type). - Computation argument = 2; -} - -// A representation of a Struct constructor. -// -// The concise representation of a Struct constructor is "<>"-enclosed and -// comma-separated list of value or "name=value" sections, for example "<1,2>" -// or "". -message Struct { - // The ordering of Struct elements is significant, and determines the type of - // the value represented by the expression. The names are optional. - repeated Element element = 1; - message Element { - string name = 1; - Computation value = 2; - } -} - -// A representation of a value selected from a Struct returned by another -// computation. -// -// The concise representation of a selection is "x[index]" for positional -// selection, and "x.name" for name-based selection, where "x" represents the -// source from which to select. For example, in lambda "x -> x[0](x[1])", where -// "x[0]" and "x[1]" both represent selections of named members from the STruct -// "x", respectively. -message Selection { - // The source of selection, always required. This is a computation that - // returns a Struct (possibly nested), from which to select an element - // by name or by index. - Computation source = 1; - - // A specification of what to select from the context (Struct). Indexes, - // when applied to Structs, are 0-based, i.e., "[0]" selects the first - // element. - int32 index = 3; -} - -// A specification of an external source of data to be used by a computation. -// -// Data streams are curently expected to always be nested structures composed -// of sequences, Structs, and tensor types. Sequences cannot be nested. -// Structs may appear at the outer level (to return multiple sequences, -// e.g., training and testing samples), or at the element level (if sequences -// contain structured elements, e.g., examples already parsed into individual -// features). -// -// Although data could conceivably be modeled via intrinsics, we factor it out -// to more conveniently express various types of input pipelines without having -// to pack everything into a URI. Sources of data could include training -// examples emitted by a mobile app, files on a filesystem, data to obtain from -// a location on the web, etc., and the specification, in addition to the -// origin of the data, could include things like example selection criteria, -// data decoding or simple transformations. For now, this structure is a -// specification to be interpreted by the runtime environment. To be extended -// as needed. -message Data { - oneof data { - // A specification of the data stream as a URI to be interpreted by the - // environment. - string uri = 1 [deprecated = true]; - - // A specification of the data stream to be interpreted by the environment. - google.protobuf.Any content = 3; - } - reserved 2; -} - -// A representation of a section of XLA code (experimental-only). -// -// The type signature associated with this type of computation must be defined -// as a function which accepts and returns tensors and potentially nested -// structures of tensors. -message Xla { - // A serialized representation of XLA code to execute. - // - // Stores an `HloModuleProto` message, as defined in the TensorFlow repo in - // the file "tensorflow/compiler/xla/service/hlo.proto" in the main branch. - // - // It is recommended, albeit not required that the entry computation in this - // module accepts its parameters as a single tuple. - // - // NOTE: As it is experimental-only, this representation may evolve, possibly - // in a manner that is backwards-incompatible. Make sure not to depend on the - // current form of this representation, and not to persist it in places where - // subsequent changes could cause breakages. - google.protobuf.Any hlo_module = 1; - - // A pair of bindings for the parameter and the result. The parameter binding - // can be omitted if the computation does not declare a parameter. The result - // binding is mandatory, as all XLA computations must declare results. - Binding parameter = 2; - Binding result = 3; - - // A general representation of a binding of either a parameter or a result to - // a part of the embedded HLO module. Note that the structure of the binding - // is nested, and it parallels the structure of the corresponding part of the - // type signature. - message Binding { - oneof binding { - StructBinding struct = 1; - TensorBinding tensor = 2; - } - } - - // A binding associated with a struct in the type signature. Specifies an - // individual binding for each element of the struct. - message StructBinding { - // Bindings for elements of the struct. The number of elements in this field - // must be equal to the number of struct elements declared in the type - // signature, with the k-th binding declared here corresponding to the k-th - // struct element in the type signature. The element names are omitted since - // they are redundant (correspondence is established by element order). - repeated Binding element = 1; - } - - // A binding associated with a (logical) tensor in the type signature. - // Associates that tensor to one or more (concrete) tensors in the inputs - // or outputs of a computation in the module. - message TensorBinding { - oneof binding { - // The 0-based index of this tensor in (the flattened form of) either the - // parameter or result tuple for the entry computation of the HLO module, - // i.e., the `HloComputationProto` with the id matching the module's - // `entry_computation_id`. - // - // The order of indexes associated with the result tensors is defined by - // the order in which tensors appear in the DFS traversal of the root - // instruction in the computation (which can be a tensor, or a possibly - // recursively nested tuple). For example, if the XLA computation returns - // a nested tuple ((int32, int32), int32), the indexes of the tensors in - // the result are ((0, 1), 2), accordingly. - // - // The order of indexes for parameter tensors is defined likewise. In the - // case of multiple arguments, tensor indexes are determined by traversing - // arguments in the order in which they appear on the parameter list (the - // order of `parameter_number` in the `HloInstructionProto`s. - // For example, if the computation takes 2 arguments, the first of which - // is a 2-tuple of tensors, and the second of which is a tensor, the - // indexes identifying the individual portions of the argument list would - // be (0, 1), 2, i.e., 0 would refer to the first tuple element of the - // first parameter, etc. - int32 index = 1; - } - } -} - -// A representation of a literal value. -// -// The type signature associated with this type of computation must be defined -// as a tensor. -message Literal { - Array value = 1; -} diff --git a/tensorflow_federated/proto/v0/data_type.proto b/tensorflow_federated/proto/v0/data_type.proto deleted file mode 100644 index 14b902a744..0000000000 --- a/tensorflow_federated/proto/v0/data_type.proto +++ /dev/null @@ -1,31 +0,0 @@ -syntax = "proto3"; - -package tensorflow_federated.v0; - -enum DataType { - // Sorted by first kind (bool, int, uint, float, complex, string), then by - // wether the dtype exists natively in numpy, and finally by bit width. - DT_INVALID = 0; - DT_BOOL = 10; - DT_INT8 = 6; - DT_INT16 = 5; - DT_INT32 = 3; - DT_INT64 = 9; - DT_UINT8 = 4; - DT_UINT16 = 17; - DT_UINT32 = 22; - DT_UINT64 = 23; - DT_HALF = 19; - DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_COMPLEX64 = 8; - DT_COMPLEX128 = 18; - DT_BFLOAT16 = 14; - DT_STRING = 7; - - reserved 11; // DT_QINT8 - reserved 12; // DT_QUINT8 - reserved 13; // DT_QINT32 - reserved 15; // DT_QINT16 - reserved 16; // DT_QUINT16 -} diff --git a/tensorflow_federated/proto/v0/executor.proto b/tensorflow_federated/proto/v0/executor.proto index 313b0e1da0..b2b6f31689 100644 --- a/tensorflow_federated/proto/v0/executor.proto +++ b/tensorflow_federated/proto/v0/executor.proto @@ -2,8 +2,8 @@ syntax = "proto3"; package tensorflow_federated.v0; -import "tensorflow_federated/proto/v0/array.proto"; -import "tensorflow_federated/proto/v0/computation.proto"; +import "federated_language/proto/array.proto"; +import "federated_language/proto/computation.proto"; // A service providing computation execution. service ExecutorGroup { @@ -44,7 +44,7 @@ service ExecutorGroup { } message Cardinality { - tensorflow_federated.v0.Placement placement = 1; + federated_language.Placement placement = 1; int32 cardinality = 2; } @@ -148,11 +148,11 @@ message Value { message Sequence { // The TensorFlow Federated `Type` of the elements in this // sequence. - tensorflow_federated.v0.Type element_type = 2; + federated_language.Type element_type = 2; // A representation of a sequence of values. message Element { - repeated tensorflow_federated.v0.Array flat_value = 1; + repeated federated_language.Array flat_value = 1; } repeated Element element = 4; @@ -163,7 +163,7 @@ message Value { // A representation of a federated value. message Federated { // The type of the federated value. - tensorflow_federated.v0.FederatedType type = 1; + federated_language.FederatedType type = 1; // The member constituents, one per participant in the collective defined // by this value's placement within the executor. @@ -172,12 +172,12 @@ message Value { oneof value { // An array value. - tensorflow_federated.v0.Array array = 6; + federated_language.Array array = 6; // A serialized TFF computation; this is the canonical (and currently only) // way to pass any functional constructs, but the computation included here // does not necessarily have to be of a functional type. - tensorflow_federated.v0.Computation computation = 2; + federated_language.Computation computation = 2; // A struct of values. Struct struct = 3; diff --git a/tensorflow_federated/python/aggregators/BUILD b/tensorflow_federated/python/aggregators/BUILD index 5f4cccafb7..c363c27a61 100644 --- a/tensorflow_federated/python/aggregators/BUILD +++ b/tensorflow_federated/python/aggregators/BUILD @@ -49,12 +49,9 @@ py_library( ":factory", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -66,10 +63,9 @@ py_test( ":aggregator_test_utils", ":factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -81,14 +77,9 @@ py_library( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -100,11 +91,9 @@ py_test( ":mean", ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -116,13 +105,9 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -136,12 +121,9 @@ py_test( ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -154,12 +136,9 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -173,10 +152,9 @@ py_test( ":factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -187,13 +165,9 @@ py_library( ":factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -204,11 +178,9 @@ py_test( ":discretization", ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -227,14 +199,9 @@ py_library( ":secure", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -252,12 +219,10 @@ py_test( ":rotation", ":secure", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -269,13 +234,9 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -287,10 +248,9 @@ py_test( ":encoded", ":factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -298,8 +258,8 @@ py_library( name = "factory", srcs = ["factory.py"], deps = [ - "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/templates:aggregation_process", + "@federated_language//federated_language", ], ) @@ -309,10 +269,8 @@ py_library( deps = [ ":factory", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", + "@federated_language//federated_language", ], ) @@ -325,7 +283,7 @@ py_test( ":mean", ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -348,13 +306,9 @@ py_library( ":sum_factory", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -367,10 +321,9 @@ py_test( ":factory", ":mean", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -381,12 +334,9 @@ py_library( ":factory", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -399,9 +349,7 @@ py_test( ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -413,14 +361,9 @@ py_library( ":sum_factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -431,10 +374,9 @@ py_test( ":modular_clipping", ":sum_factory", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -446,10 +388,7 @@ py_library( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/federated_context:value_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -460,11 +399,7 @@ py_test( deps = [ ":primitives", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -477,11 +412,8 @@ py_library( ":sum_factory", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:estimation_process", + "@federated_language//federated_language", ], ) @@ -491,10 +423,8 @@ py_test( deps = [ ":quantile_estimation", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:estimation_process", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -506,15 +436,10 @@ py_library( ":sum_factory", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -528,14 +453,10 @@ py_test( ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -549,13 +470,9 @@ py_library( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -568,11 +485,9 @@ py_test( ":rotation", ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -584,15 +499,9 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_transformations", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -602,9 +511,7 @@ py_test( deps = [ ":sampling", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -618,14 +525,10 @@ py_library( "//tensorflow_federated/python/core/backends/mapreduce:intrinsics", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -638,14 +541,10 @@ py_test( ":secure", "//tensorflow_federated/python/core/backends/test:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -657,13 +556,9 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -677,12 +572,9 @@ py_test( ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -692,12 +584,9 @@ py_library( deps = [ ":factory", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -709,9 +598,8 @@ py_test( ":factory", ":sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/aggregators/aggregator_test_utils.py b/tensorflow_federated/python/aggregators/aggregator_test_utils.py index c479bc646f..01fab2743f 100644 --- a/tensorflow_federated/python/aggregators/aggregator_test_utils.py +++ b/tensorflow_federated/python/aggregators/aggregator_test_utils.py @@ -15,15 +15,12 @@ import typing +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -52,26 +49,28 @@ def create( type_args = typing.get_args(factory.ValueType) py_typecheck.check_type(value_type, type_args) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): - state = intrinsics.federated_map( + state = federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: x + 1), state ) - result = intrinsics.federated_map( + result = federated_language.federated_map( tensorflow_computation.tf_computation( lambda x: tf.nest.map_structure(lambda y: y + 1, x) ), - intrinsics.federated_sum(value), + federated_language.federated_sum(value), ) - measurements = intrinsics.federated_value( - MEASUREMENT_CONSTANT, placements.SERVER + measurements = federated_language.federated_value( + MEASUREMENT_CONSTANT, federated_language.SERVER ) return measured_process.MeasuredProcessOutput(state, result, measurements) diff --git a/tensorflow_federated/python/aggregators/aggregator_test_utils_test.py b/tensorflow_federated/python/aggregators/aggregator_test_utils_test.py index 6fa5202030..f8a87c9d03 100644 --- a/tensorflow_federated/python/aggregators/aggregator_test_utils_test.py +++ b/tensorflow_federated/python/aggregators/aggregator_test_utils_test.py @@ -15,14 +15,13 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import aggregator_test_utils from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -37,24 +36,24 @@ class SumPlusOneFactoryComputationTest( def test_type_properties(self, value_type): sum_f = aggregator_test_utils.SumPlusOneFactory() self.assertIsInstance(sum_f, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = sum_f.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType( - np.int32, placements.SERVER + expected_state_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - expected_measurements_type = computation_types.FederatedType( - np.int32, placements.SERVER + expected_measurements_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -63,7 +62,7 @@ def test_type_properties(self, value_type): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, value=param_value_type ), @@ -78,10 +77,12 @@ def test_type_properties(self, value_type): @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_value_type_raises(self, bad_value_type): sum_f = aggregator_test_utils.SumPlusOneFactory() @@ -93,7 +94,7 @@ class SumPlusOneFactoryExecutionTest(tf.test.TestCase): def test_sum_scalar(self): sum_f = aggregator_test_utils.SumPlusOneFactory() - value_type = computation_types.to_type(np.float32) + value_type = federated_language.to_type(np.float32) process = sum_f.create(value_type) state = process.initialize() @@ -107,7 +108,7 @@ def test_sum_scalar(self): def test_sum_structure(self): sum_f = aggregator_test_utils.SumPlusOneFactory() - value_type = computation_types.to_type(((np.float32, (2,)), np.int32)) + value_type = federated_language.to_type(((np.float32, (2,)), np.int32)) process = sum_f.create(value_type) state = process.initialize() diff --git a/tensorflow_federated/python/aggregators/concat.py b/tensorflow_federated/python/aggregators/concat.py index e2956cee72..271dd89722 100644 --- a/tensorflow_federated/python/aggregators/concat.py +++ b/tensorflow_federated/python/aggregators/concat.py @@ -16,18 +16,13 @@ import functools from typing import TypeVar +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -56,13 +51,13 @@ def _next_fn_impl( state, value, concat_fn, unconcat_fn, inner_agg_process, weight=None ): """Implements the next_fn for concat_factory's resulting AggregationProcess.""" - concat_value = intrinsics.federated_map(concat_fn, value) + concat_value = federated_language.federated_map(concat_fn, value) if weight is None: inner_agg_output = inner_agg_process.next(state, concat_value) else: inner_agg_output = inner_agg_process.next(state, concat_value, weight) - unconcat_value = intrinsics.federated_map( + unconcat_value = federated_language.federated_map( unconcat_fn, inner_agg_output.result ) return measured_process.MeasuredProcessOutput( @@ -74,17 +69,20 @@ def _next_fn_impl( def create_concat_fns( value_type: factory.ValueType, -) -> tuple[computation_base.Computation, computation_base.Computation]: +) -> tuple[ + federated_language.framework.Computation, + federated_language.framework.Computation, +]: """Creates the forward and backward flattening/concatenation functions.""" # As the factory alters the tensor specs, we compute the Python structure # of the types for the unconcat procedure. if isinstance( - value_type, computation_types.StructWithPythonType - ) and type_analysis.is_structure_of_tensors(value_type): + value_type, federated_language.StructWithPythonType + ) and federated_language.framework.is_structure_of_tensors(value_type): original_structure = type_conversions.structure_from_tensor_type_tree( lambda x: tf.TensorSpec(x.shape, x.dtype), value_type ) - elif isinstance(value_type, computation_types.TensorType): + elif isinstance(value_type, federated_language.TensorType): original_structure = tf.TensorSpec(value_type.shape, value_type.dtype) else: raise TypeError( @@ -118,8 +116,8 @@ def _check_component_dtypes(value_type): # Restrict dtypes to integers and floats for now. if not ( - type_analysis.is_structure_of_floats(value_type) - or type_analysis.is_structure_of_integers(value_type) + federated_language.framework.is_structure_of_floats(value_type) + or federated_language.framework.is_structure_of_integers(value_type) ): raise TypeError( 'Components of `value_type` must all be integers or ' @@ -141,9 +139,11 @@ def create(self, value_type) -> aggregation_process.AggregationProcess: init_fn = inner_agg_process.initialize state_type = init_fn.type_signature.result - @federated_computation.federated_computation( + @federated_language.federated_computation( state_type, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): return _next_fn_impl( @@ -170,10 +170,14 @@ def create( ) init_fn = inner_agg_process.initialize - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), - computation_types.FederatedType(weight_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) def next_fn(state, value, weight): return _next_fn_impl( diff --git a/tensorflow_federated/python/aggregators/concat_test.py b/tensorflow_federated/python/aggregators/concat_test.py index e37b396b18..d8c9aac426 100644 --- a/tensorflow_federated/python/aggregators/concat_test.py +++ b/tensorflow_federated/python/aggregators/concat_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -22,9 +23,6 @@ from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -65,40 +63,42 @@ class ConcatFactoryComputationTest(tf.test.TestCase, parameterized.TestCase): ) def test_concat_type_properties_unweighted(self, value_type): factory = _concat_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) # Inner SumFactory has no state. - server_state_type = computation_types.FederatedType((), placements.SERVER) + server_state_type = federated_language.FederatedType( + (), federated_language.SERVER + ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.initialize.type_signature, expected_initialize_type ) # Inner SumFactory has no measurements. - expected_measurements_type = computation_types.FederatedType( - (), placements.SERVER + expected_measurements_type = federated_language.FederatedType( + (), federated_language.SERVER ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.next.type_signature, expected_next_type ) @@ -110,48 +110,48 @@ def test_concat_type_properties_unweighted(self, value_type): ) def test_clip_type_properties_weighted(self, value_type, weight_type): factory = _concat_mean() - value_type = computation_types.to_type(value_type) - weight_type = computation_types.to_type(weight_type) + value_type = federated_language.to_type(value_type) + weight_type = federated_language.to_type(weight_type) process = factory.create(value_type, weight_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) # State comes from the inner MeanFactory. - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict(value_sum_process=(), weight_sum_process=()), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.initialize.type_signature, expected_initialize_type ) # Measurements come from the inner mean factory. - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict(mean_value=(), mean_weight=()), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), - weight=computation_types.FederatedType( - weight_type, placements.CLIENTS + weight=federated_language.FederatedType( + weight_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.next.type_signature, expected_next_type ) @@ -163,7 +163,7 @@ def test_clip_type_properties_weighted(self, value_type, weight_type): ) def test_raises_on_non_numeric_dtypes(self, value_type): factory = _concat_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'must all be integers or floats'): factory.create(value_type) @@ -178,20 +178,23 @@ def test_raises_on_non_numeric_dtypes(self, value_type): ) def test_raises_on_mixed_dtypes(self, value_type): factory = _concat_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'should have the same dtype'): factory.create(value_type) @parameterized.named_parameters( ('plain_struct', [('a', np.int32)]), - ('sequence', computation_types.SequenceType(np.int32)), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('nested_sequence', [[[computation_types.SequenceType(np.int32)]]]), - ('nested_function', [computation_types.FunctionType(np.int32, np.int32)]), + ('sequence', federated_language.SequenceType(np.int32)), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('nested_sequence', [[[federated_language.SequenceType(np.int32)]]]), + ( + 'nested_function', + [federated_language.FunctionType(np.int32, np.int32)], + ), ) def test_raises_on_bad_tff_value_types(self, value_type): factory = _concat_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'): factory.create(value_type) @@ -216,7 +219,7 @@ class ConcatFactoryExecutionTest(tf.test.TestCase, parameterized.TestCase): ) def test_concat_sum(self, value_type, client_data, expected_sum): factory = _concat_sum() - process = factory.create(computation_types.to_type(value_type)) + process = factory.create(federated_language.to_type(value_type)) state = process.initialize() self.assertEqual(state, ()) @@ -255,8 +258,8 @@ def test_concat_mean( ): factory = _concat_mean() process = factory.create( - computation_types.to_type(value_type), - computation_types.to_type(np.float32), + federated_language.to_type(value_type), + federated_language.to_type(np.float32), ) expected_state = collections.OrderedDict( diff --git a/tensorflow_federated/python/aggregators/deterministic_discretization.py b/tensorflow_federated/python/aggregators/deterministic_discretization.py index 669a1f1573..2d9d0a35e5 100644 --- a/tensorflow_federated/python/aggregators/deterministic_discretization.py +++ b/tensorflow_federated/python/aggregators/deterministic_discretization.py @@ -16,6 +16,7 @@ import collections from typing import Optional +import federated_language import numpy as np import tensorflow as tf @@ -23,11 +24,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -96,11 +92,11 @@ def create( self, value_type: factory.ValueType ) -> aggregation_process.AggregationProcess: # Validate input args and value_type and parse out the TF dtypes. - if isinstance(value_type, computation_types.TensorType): + if isinstance(value_type, federated_language.TensorType): tf_dtype = value_type.dtype elif isinstance( - value_type, computation_types.StructWithPythonType - ) and type_analysis.is_structure_of_tensors(value_type): + value_type, federated_language.StructWithPythonType + ) and federated_language.framework.is_structure_of_tensors(value_type): tf_dtype = type_conversions.structure_from_tensor_type_tree( lambda x: x.dtype, value_type ) @@ -112,7 +108,7 @@ def create( ) # Check that all values are floats. - if not type_analysis.is_structure_of_floats(value_type): + if not federated_language.framework.is_structure_of_floats(value_type): raise TypeError( 'Component dtypes of `value_type` must all be floats. ' f'Found {repr(value_type)}.' @@ -120,7 +116,7 @@ def create( if self._distortion_aggregation_factory is not None: distortion_aggregation_process = self._distortion_aggregation_factory.create( - computation_types.to_type(np.float32) # pytype: disable=wrong-arg-types + federated_language.to_type(np.float32) # pytype: disable=wrong-arg-types ) @tensorflow_computation.tf_computation(value_type, np.float32) @@ -152,32 +148,36 @@ def distortion_measurement_fn(value, step_size): discretize_fn.type_signature.result ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): state = collections.OrderedDict( - step_size=intrinsics.federated_value( - self._step_size, placements.SERVER + step_size=federated_language.federated_value( + self._step_size, federated_language.SERVER ), inner_agg_process=inner_agg_process.initialize(), ) - return intrinsics.federated_zip(state) + return federated_language.federated_zip(state) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): server_step_size = state['step_size'] - client_step_size = intrinsics.federated_broadcast(server_step_size) + client_step_size = federated_language.federated_broadcast( + server_step_size + ) - discretized_value = intrinsics.federated_map( + discretized_value = federated_language.federated_map( discretize_fn, (value, client_step_size) ) inner_state = state['inner_agg_process'] inner_agg_output = inner_agg_process.next(inner_state, discretized_value) - undiscretized_agg_value = intrinsics.federated_map( + undiscretized_agg_value = federated_language.federated_map( undiscretize_fn, (inner_agg_output.result, server_step_size) ) @@ -189,7 +189,7 @@ def next_fn(state, value): ) if self._distortion_aggregation_factory is not None: - distortions = intrinsics.federated_map( + distortions = federated_language.federated_map( distortion_measurement_fn, (value, client_step_size) ) aggregate_distortion = distortion_aggregation_process.next( @@ -198,9 +198,9 @@ def next_fn(state, value): measurements['distortion'] = aggregate_distortion return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip(new_state), + state=federated_language.federated_zip(new_state), result=undiscretized_agg_value, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return aggregation_process.AggregationProcess(init_fn, next_fn) diff --git a/tensorflow_federated/python/aggregators/deterministic_discretization_test.py b/tensorflow_federated/python/aggregators/deterministic_discretization_test.py index 8fa8dc06a3..f3e7852433 100644 --- a/tensorflow_federated/python/aggregators/deterministic_discretization_test.py +++ b/tensorflow_federated/python/aggregators/deterministic_discretization_test.py @@ -16,6 +16,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -25,10 +26,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -61,7 +58,8 @@ def _named_test_cases_product(*args): _measurement_aggregator = measurements.add_measurements( - sum_factory.SumFactory(), client_measurement_fn=intrinsics.federated_sum + sum_factory.SumFactory(), + client_measurement_fn=federated_language.federated_sum, ) @@ -81,49 +79,49 @@ def test_type_properties(self, value_type): inner_agg_factory=_measurement_aggregator, distortion_aggregation_factory=mean.UnweightedMeanFactory(), ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) quantize_type = type_conversions.structure_from_tensor_type_tree( lambda x: (np.int32, x.shape), value_type ) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.StructType( + server_state_type = federated_language.StructType( [('step_size', np.float32), ('inner_agg_process', ())] ) - server_state_type = computation_types.FederatedType( - server_state_type, placements.SERVER + server_state_type = federated_language.FederatedType( + server_state_type, federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.initialize.type_signature, expected_initialize_type ) - expected_measurements_type = computation_types.StructType([ + expected_measurements_type = federated_language.StructType([ ('deterministic_discretization', quantize_type), ('distortion', np.float32), ]) - expected_measurements_type = computation_types.FederatedType( - expected_measurements_type, placements.SERVER + expected_measurements_type = federated_language.FederatedType( + expected_measurements_type, federated_language.SERVER ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.next.type_signature, expected_next_type ) @@ -138,20 +136,20 @@ def test_raises_on_bad_component_tensor_dtypes(self, value_type): factory = deterministic_discretization.DeterministicDiscretizationFactory( inner_agg_factory=_measurement_aggregator, step_size=0.1 ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) self.assertRaises(TypeError, factory.create, value_type) @parameterized.named_parameters( ('plain_struct', [('a', np.int32)]), - ('sequence', computation_types.SequenceType(np.int32)), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('nested_sequence', [[[computation_types.SequenceType(np.int32)]]]), + ('sequence', federated_language.SequenceType(np.int32)), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('nested_sequence', [[[federated_language.SequenceType(np.int32)]]]), ) def test_raises_on_bad_tff_value_types(self, value_type): factory = deterministic_discretization.DeterministicDiscretizationFactory( inner_agg_factory=_measurement_aggregator, step_size=0.1 ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) self.assertRaises(TypeError, factory.create, value_type) @@ -192,7 +190,7 @@ def test_discretize_impl(self, value_type, client_values, expected_sum): step_size=0.1, distortion_aggregation_factory=mean.UnweightedMeanFactory(), ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) state = process.initialize() diff --git a/tensorflow_federated/python/aggregators/differential_privacy.py b/tensorflow_federated/python/aggregators/differential_privacy.py index a01243d2b5..88dc533380 100644 --- a/tensorflow_federated/python/aggregators/differential_privacy.py +++ b/tensorflow_federated/python/aggregators/differential_privacy.py @@ -21,6 +21,7 @@ from absl import logging import dp_accounting +import federated_language import tensorflow as tf import tensorflow_privacy as tfp @@ -29,10 +30,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -517,49 +514,59 @@ def get_noised_result(sample_state, global_state): lambda event: event.to_named_tuple(), dp_event_type ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - query_initial_state = intrinsics.federated_eval( - query_initial_state_fn, placements.SERVER + query_initial_state = federated_language.federated_eval( + query_initial_state_fn, federated_language.SERVER ) - query_sample_state = intrinsics.federated_eval( - query_sample_state_fn, placements.SERVER + query_sample_state = federated_language.federated_eval( + query_sample_state_fn, federated_language.SERVER ) - _, _, dp_event = intrinsics.federated_map( + _, _, dp_event = federated_language.federated_map( get_noised_result, (query_sample_state, query_initial_state) ) - dp_event = intrinsics.federated_map(convert_dp_event, dp_event) - is_init_state = intrinsics.federated_value(True, placements.SERVER) + dp_event = federated_language.federated_map(convert_dp_event, dp_event) + is_init_state = federated_language.federated_value( + True, federated_language.SERVER + ) init_state = DPAggregatorState( query_initial_state, record_agg_process.initialize(), dp_event, is_init_state, ) - return intrinsics.federated_zip(init_state) + return federated_language.federated_zip(init_state) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): query_state, agg_state, _, _ = state - params = intrinsics.federated_broadcast( - intrinsics.federated_map(derive_sample_params, query_state) + params = federated_language.federated_broadcast( + federated_language.federated_map(derive_sample_params, query_state) + ) + record = federated_language.federated_map( + get_query_record, (params, value) ) - record = intrinsics.federated_map(get_query_record, (params, value)) record_agg_output = record_agg_process.next(agg_state, record) - result, new_query_state, dp_event = intrinsics.federated_map( + result, new_query_state, dp_event = federated_language.federated_map( get_noised_result, (record_agg_output.result, query_state) ) - dp_event = intrinsics.federated_map(convert_dp_event, dp_event) + dp_event = federated_language.federated_map(convert_dp_event, dp_event) - is_init_state = intrinsics.federated_value(False, placements.SERVER) + is_init_state = federated_language.federated_value( + False, federated_language.SERVER + ) - query_metrics = intrinsics.federated_map(derive_metrics, new_query_state) + query_metrics = federated_language.federated_map( + derive_metrics, new_query_state + ) new_state = DPAggregatorState( new_query_state, @@ -571,9 +578,9 @@ def next_fn(state, value): dp_query_metrics=query_metrics, dp=record_agg_output.measurements ) return measured_process.MeasuredProcessOutput( - intrinsics.federated_zip(new_state), + federated_language.federated_zip(new_state), result, - intrinsics.federated_zip(measurements), + federated_language.federated_zip(measurements), ) return aggregation_process.AggregationProcess(init_fn, next_fn) diff --git a/tensorflow_federated/python/aggregators/differential_privacy_test.py b/tensorflow_federated/python/aggregators/differential_privacy_test.py index 05cc0739cb..58e7715a64 100644 --- a/tensorflow_federated/python/aggregators/differential_privacy_test.py +++ b/tensorflow_federated/python/aggregators/differential_privacy_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf import tensorflow_privacy as tfp @@ -24,8 +25,6 @@ from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -48,27 +47,27 @@ def test_type_properties(self, value_type, inner_agg_factory): _test_dp_query, inner_agg_factory ) self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory_.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - query_state_type = computation_types.StructType( + query_state_type = federated_language.StructType( [('l2_norm_clip', np.float32), ('stddev', np.float32)] ) query_metrics_type = () inner_state_type = np.int32 if inner_agg_factory else () - dp_event_type = computation_types.StructType([ + dp_event_type = federated_language.StructType([ ('module_name', np.str_), ('class_name', np.str_), ('noise_multiplier', np.float32), ]) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( differential_privacy.DPAggregatorState( query_state_type, inner_state_type, dp_event_type, np.bool_ ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -77,24 +76,24 @@ def test_type_properties(self, value_type, inner_agg_factory): ) ) inner_measurements_type = np.int32 if inner_agg_factory else () - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( dp_query_metrics=query_metrics_type, dp=inner_measurements_type ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -116,10 +115,12 @@ def test_init_non_agg_factory_raises(self): @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_value_type_raises(self, bad_value_type): factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query) @@ -161,7 +162,7 @@ def assertInnerSumPlusOnePerformed( def test_simple_sum(self): factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory_.create(value_type) # The test query has clip 1.0 and no noise, so this computes clipped sum. @@ -174,7 +175,7 @@ def test_simple_sum(self): def test_structure_sum(self): factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query) - value_type = computation_types.to_type([np.float32, np.float32]) + value_type = federated_language.to_type([np.float32, np.float32]) process = factory_.create(value_type) # The test query has clip 1.0 and no noise, so this computes clipped sum. @@ -195,7 +196,7 @@ def test_structure_sum(self): self.assertAllClose(expected_result, output.result) def test_inner_sum(self): - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) factory_ = differential_privacy.DifferentiallyPrivateFactory( _test_dp_query, _test_inner_agg_factory ) @@ -212,7 +213,7 @@ def test_inner_sum(self): def test_tree_aggregation_inner_sum(self): l2_clip = 1.0 - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) tree_factory = ( differential_privacy.DifferentiallyPrivateFactory.tree_aggregation( noise_multiplier=0.0, @@ -242,7 +243,7 @@ def test_adaptive_query(self): geometric_update=False, ) factory_ = differential_privacy.DifferentiallyPrivateFactory(query) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory_.create(value_type) state = process.initialize() @@ -258,7 +259,7 @@ def test_adaptive_query(self): self.assertAllClose(expected_result, output.result) def test_extract_dp_event_from_state(self): - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query) process = factory_.create(value_type) state = process.initialize() @@ -275,7 +276,7 @@ def test_extract_dp_event_from_state(self): self.assertEqual(event, expected_dp_event) def test_error_when_extracting_from_initial_state(self): - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) factory_ = differential_privacy.DifferentiallyPrivateFactory(_test_dp_query) process = factory_.create(value_type) state = process.initialize() @@ -289,7 +290,7 @@ def test_noise(self): factory_ = differential_privacy.DifferentiallyPrivateFactory.gaussian_fixed( noise_multiplier=noise, clients_per_round=1.0, clip=1.0 ) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory_.create(value_type) state = process.initialize() @@ -362,7 +363,7 @@ def test_tree_aggregation_factory( variable_shape, tolerance = [10000], 0.05 record = np.zeros(variable_shape, np.float32) record_shape = variable_shape - record_type = computation_types.to_type((np.float32, variable_shape)) + record_type = federated_language.to_type((np.float32, variable_shape)) specs = tf.TensorSpec(shape=record_shape, dtype=tf.float32) tree_factory = ( @@ -417,7 +418,7 @@ def test_tree_adaptive_factory_estimate_clip(self): clipped_count_stddev=0.0, noise_seed=1, ) - process = factory_.create(computation_types.TensorType(np.float32)) + process = factory_.create(federated_language.TensorType(np.float32)) state = process.initialize() diff --git a/tensorflow_federated/python/aggregators/discretization.py b/tensorflow_federated/python/aggregators/discretization.py index caaac56b83..4bc1d15f59 100644 --- a/tensorflow_federated/python/aggregators/discretization.py +++ b/tensorflow_federated/python/aggregators/discretization.py @@ -16,17 +16,13 @@ import collections import numbers +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -138,11 +134,11 @@ def __init__( def create(self, value_type): # Validate input args and value_type and parse out the TF dtypes. - if isinstance(value_type, computation_types.TensorType): + if isinstance(value_type, federated_language.TensorType): tf_dtype = value_type.dtype elif isinstance( - value_type, computation_types.StructWithPythonType - ) and type_analysis.is_structure_of_tensors(value_type): + value_type, federated_language.StructWithPythonType + ) and federated_language.framework.is_structure_of_tensors(value_type): if self._prior_norm_bound: raise TypeError( 'If `prior_norm_bound` is specified, `value_type` must ' @@ -159,7 +155,7 @@ def create(self, value_type): ) # Check that all values are floats. - if not type_analysis.is_structure_of_floats(value_type): + if not federated_language.framework.is_structure_of_floats(value_type): raise TypeError( 'Component dtypes of `value_type` must all be floats. ' f'Found {repr(value_type)}.' @@ -178,37 +174,43 @@ def undiscretize_fn(value, scale_factor): inner_value_type = discretize_fn.type_signature.result inner_agg_process = self._inner_agg_factory.create(inner_value_type) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): state = collections.OrderedDict( - scale_factor=intrinsics.federated_value( - self._scale_factor, placements.SERVER + scale_factor=federated_language.federated_value( + self._scale_factor, federated_language.SERVER ), - prior_norm_bound=intrinsics.federated_value( - self._prior_norm_bound, placements.SERVER + prior_norm_bound=federated_language.federated_value( + self._prior_norm_bound, federated_language.SERVER ), inner_agg_process=inner_agg_process.initialize(), ) - return intrinsics.federated_zip(state) + return federated_language.federated_zip(state) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): server_scale_factor = state['scale_factor'] - client_scale_factor = intrinsics.federated_broadcast(server_scale_factor) + client_scale_factor = federated_language.federated_broadcast( + server_scale_factor + ) server_prior_norm_bound = state['prior_norm_bound'] - prior_norm_bound = intrinsics.federated_broadcast(server_prior_norm_bound) + prior_norm_bound = federated_language.federated_broadcast( + server_prior_norm_bound + ) - discretized_value = intrinsics.federated_map( + discretized_value = federated_language.federated_map( discretize_fn, (value, client_scale_factor, prior_norm_bound) ) inner_state = state['inner_agg_process'] inner_agg_output = inner_agg_process.next(inner_state, discretized_value) - undiscretized_agg_value = intrinsics.federated_map( + undiscretized_agg_value = federated_language.federated_map( undiscretize_fn, (inner_agg_output.result, server_scale_factor) ) @@ -222,9 +224,9 @@ def next_fn(state, value): ) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip(new_state), + state=federated_language.federated_zip(new_state), result=undiscretized_agg_value, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return aggregation_process.AggregationProcess(init_fn, next_fn) diff --git a/tensorflow_federated/python/aggregators/discretization_test.py b/tensorflow_federated/python/aggregators/discretization_test.py index a3aa95a1b2..4cb5b55c8f 100644 --- a/tensorflow_federated/python/aggregators/discretization_test.py +++ b/tensorflow_federated/python/aggregators/discretization_test.py @@ -15,15 +15,13 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import discretization from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -82,45 +80,45 @@ class DiscretizationFactoryComputationTest( ) def test_type_properties(self, value_type): factory = _discretization_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( scale_factor=np.float32, prior_norm_bound=np.float32, inner_agg_process=(), ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.initialize.type_signature, expected_initialize_type ) - expected_measurements_type = computation_types.FederatedType( - collections.OrderedDict(discretize=()), placements.SERVER + expected_measurements_type = federated_language.FederatedType( + collections.OrderedDict(discretize=()), federated_language.SERVER ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.next.type_signature, expected_next_type ) @@ -133,19 +131,19 @@ def test_type_properties(self, value_type): ) def test_raises_on_bad_component_tensor_dtypes(self, value_type): factory = _discretization_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'must all be floats'): factory.create(value_type) @parameterized.named_parameters( ('plain_struct', [('a', np.int32)]), - ('sequence', computation_types.SequenceType(np.int32)), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('nested_sequence', [[[computation_types.SequenceType(np.int32)]]]), + ('sequence', federated_language.SequenceType(np.int32)), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('nested_sequence', [[[federated_language.SequenceType(np.int32)]]]), ) def test_raises_on_bad_tff_value_types(self, value_type): factory = _discretization_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'): factory.create(value_type) @@ -226,7 +224,7 @@ def test_sum(self, value_type, client_data, expected_sum, stochastic): """Integration test with sum.""" scale_factor = 3 factory = _discretization_sum(scale_factor, stochastic=stochastic) - process = factory.create(computation_types.to_type(value_type)) + process = factory.create(federated_language.to_type(value_type)) state = process.initialize() for _ in range(3): diff --git a/tensorflow_federated/python/aggregators/distributed_dp.py b/tensorflow_federated/python/aggregators/distributed_dp.py index 14189d491b..03b2dcb70f 100644 --- a/tensorflow_federated/python/aggregators/distributed_dp.py +++ b/tensorflow_federated/python/aggregators/distributed_dp.py @@ -18,6 +18,7 @@ from typing import Optional import warnings +import federated_language import numpy as np import tensorflow as tf import tensorflow_privacy as tfp @@ -33,12 +34,6 @@ from tensorflow_federated.python.aggregators import secure from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -555,10 +550,10 @@ def _update_dp_params(agg_state, new_l2_clip): # NOTE(b/170893510): Explicitly declaring Union[float, EstimationProcess] # for _l2_clip or doing isinstance() check still triggers attribute-error. new_l2_clip = self._l2_clip.report(l2_clip_state['clipping_norm']) # pytype: disable=attribute-error - agg_state = intrinsics.federated_map( + agg_state = federated_language.federated_map( _update_scale, (agg_state, new_l2_clip) ) - agg_state = intrinsics.federated_map( + agg_state = federated_language.federated_map( _update_dp_params, (agg_state, new_l2_clip) ) return agg_state @@ -568,11 +563,12 @@ def _derive_measurements(self, agg_state, agg_measurements): l2_clip_metrics, _, dp_metrics = self._unpack_measurements(agg_measurements) dp_query_state, _, _, _ = dp_state - actual_num_clients = intrinsics.federated_secure_sum_bitwidth( - intrinsics.federated_value(1, placements.CLIENTS), bitwidth=1 + actual_num_clients = federated_language.federated_secure_sum_bitwidth( + federated_language.federated_value(1, federated_language.CLIENTS), + bitwidth=1, ) - padded_dim = intrinsics.federated_value( - int(self._padded_dim), placements.SERVER + padded_dim = federated_language.federated_value( + int(self._padded_dim), federated_language.SERVER ) measurements = collections.OrderedDict( @@ -585,19 +581,22 @@ def _derive_measurements(self, agg_state, agg_measurements): dp_query_metrics=dp_metrics['dp_query_metrics'], ) - return intrinsics.federated_zip(measurements) + return federated_language.federated_zip(measurements) def create(self, value_type): # Checks value_type and compute client data dimension. if isinstance( - value_type, computation_types.StructWithPythonType - ) and type_analysis.is_structure_of_tensors(value_type): + value_type, federated_language.StructWithPythonType + ) and federated_language.framework.is_structure_of_tensors(value_type): num_elements_struct = type_conversions.structure_from_tensor_type_tree( - lambda x: array_shape.num_elements_in_shape(x.shape), value_type + lambda x: federated_language.num_elements_in_array_shape(x.shape), + value_type, ) self._client_dim = sum(tf.nest.flatten(num_elements_struct)) - elif isinstance(value_type, computation_types.TensorType): - self._client_dim = array_shape.num_elements_in_shape(value_type.shape) + elif isinstance(value_type, federated_language.TensorType): + self._client_dim = federated_language.num_elements_in_array_shape( + value_type.shape + ) else: raise TypeError( 'Expected `value_type` to be `TensorType` or ' @@ -606,8 +605,8 @@ def create(self, value_type): ) # Checks that all values are integers or floats. if not ( - type_analysis.is_structure_of_floats(value_type) - or type_analysis.is_structure_of_integers(value_type) + federated_language.framework.is_structure_of_floats(value_type) + or federated_language.framework.is_structure_of_integers(value_type) ): raise TypeError( 'Component dtypes of `value_type` must all be integers ' @@ -617,9 +616,11 @@ def create(self, value_type): ddp_agg_process = self._build_aggregation_factory().create(value_type) init_fn = ddp_agg_process.initialize - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): agg_output = ddp_agg_process.next(state, value) diff --git a/tensorflow_federated/python/aggregators/distributed_dp_test.py b/tensorflow_federated/python/aggregators/distributed_dp_test.py index 3a33a462e5..3d8e827a62 100644 --- a/tensorflow_federated/python/aggregators/distributed_dp_test.py +++ b/tensorflow_federated/python/aggregators/distributed_dp_test.py @@ -16,6 +16,7 @@ import types from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf import tensorflow_privacy as tfp @@ -29,12 +30,9 @@ from tensorflow_federated.python.aggregators import rotation from tensorflow_federated.python.aggregators import secure from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.core.test import static_assert _test_struct_type = [(np.float32, (2, 2)), np.float32] _test_nested_struct_type = collections.OrderedDict( @@ -130,7 +128,7 @@ class DistributedDpComputationTest(tf.test.TestCase, parameterized.TestCase): def test_type_properties(self, value_type, mechanism): ddp_factory = _make_test_factory(mechanism=mechanism) self.assertIsInstance(ddp_factory, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = ddp_factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) @@ -172,26 +170,28 @@ def test_type_properties(self, value_type, mechanism): padded_dim=np.int32, dp_query_metrics=dp_query_metrics_type, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=expected_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), - measurements=computation_types.FederatedType( - expected_measurements_type, placements.SERVER + measurements=federated_language.FederatedType( + expected_measurements_type, federated_language.SERVER ), ), ) actual_next_type = process.next.type_signature self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type)) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) @parameterized.named_parameters( ('negative', -1, ValueError), @@ -312,13 +312,13 @@ def test_auto_l2_clip_count_stddev_raise_on(self, stddev, error_type): @parameterized.named_parameters( ('plain_struct', [('a', np.int32)]), - ('sequence', computation_types.SequenceType(np.int32)), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('nested_sequence', [[[computation_types.SequenceType(np.int32)]]]), + ('sequence', federated_language.SequenceType(np.int32)), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('nested_sequence', [[[federated_language.SequenceType(np.int32)]]]), ) def test_tff_value_types_raise_on(self, value_type): ddp_factory = _make_test_factory() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'): ddp_factory.create(value_type) @@ -329,7 +329,7 @@ def test_tff_value_types_raise_on(self, value_type): ) def test_component_tensor_dtypes_raise_on(self, value_type): test_factory = _make_test_factory() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'must all be integers or floats'): test_factory.create(value_type) @@ -374,7 +374,7 @@ def test_sum(self, name, rotation_type, beta): l2_clip=10.0, beta=beta, ) - process = ddp_factory.create(computation_types.to_type(value_type)) + process = ddp_factory.create(federated_language.to_type(value_type)) state = process.initialize() for _ in range(2): output = process.next(state, client_values) @@ -410,7 +410,7 @@ def test_auto_tuning(self, mechanism, rotation_type): dim = 99 padded_dim = 100.0 if rotation_type == 'dft' else 128.0 value_type = (np.float32, _make_onehot(0.0, dim).shape) - process = ddp_factory.create(computation_types.to_type(value_type)) + process = ddp_factory.create(federated_language.to_type(value_type)) state = process.initialize() _, discrete_state, _ = ddp_factory._unpack_state(state) cur_scale = discrete_state['scale_factor'] @@ -524,7 +524,7 @@ def test_noisy_sum(self, mechanism): bits=20, mechanism=mechanism, ) - process = ddp_factory.create(computation_types.TensorType(np.float32)) + process = ddp_factory.create(federated_language.TensorType(np.float32)) state = process.initialize() outputs = [] for _ in range(num_iterations): diff --git a/tensorflow_federated/python/aggregators/encoded.py b/tensorflow_federated/python/aggregators/encoded.py index 6a4533d873..487e3119c4 100644 --- a/tensorflow_federated/python/aggregators/encoded.py +++ b/tensorflow_federated/python/aggregators/encoded.py @@ -17,6 +17,7 @@ from collections.abc import Callable import typing +import federated_language import tensorflow as tf import tree @@ -24,11 +25,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_model_optimization.python.core.internal import tensor_encoding as te @@ -115,7 +111,10 @@ def quantize_above_threshold( _check_threshold(threshold) def encoder_fn(value_spec): - if array_shape.num_elements_in_shape(value_spec.shape) > threshold: + if ( + federated_language.num_elements_in_array_shape(value_spec.shape) + > threshold + ): return te.encoders.as_gather_encoder( te.encoders.uniform_quantization(quantization_bits, **kwargs), value_spec, @@ -157,8 +156,10 @@ def _encoded_init_fn(encoders): init_fn_tf = tensorflow_computation.tf_computation( lambda: tf.nest.map_structure(lambda e: e.initial_state(), encoders) ) - init_fn = federated_computation.federated_computation( - lambda: intrinsics.federated_eval(init_fn_tf, placements.SERVER) + init_fn = federated_language.federated_computation( + lambda: federated_language.federated_eval( + init_fn_tf, federated_language.SERVER + ) ) return init_fn @@ -320,36 +321,38 @@ def merge_fn(acc1, acc2): def report_fn(acc): return acc - @federated_computation.federated_computation( + @federated_language.federated_computation( server_state_type, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType(value_type, federated_language.CLIENTS), ) def next_fn(state, value): encode_params, decode_before_sum_params, decode_after_sum_params = ( - intrinsics.federated_map(get_params_fn, state) + federated_language.federated_map(get_params_fn, state) ) - encode_params = intrinsics.federated_broadcast(encode_params) - decode_before_sum_params = intrinsics.federated_broadcast( + encode_params = federated_language.federated_broadcast(encode_params) + decode_before_sum_params = federated_language.federated_broadcast( decode_before_sum_params ) - encoded_values = intrinsics.federated_map( + encoded_values = federated_language.federated_map( encode_fn, [value, encode_params, decode_before_sum_params] ) - aggregated_values = intrinsics.federated_aggregate( + aggregated_values = federated_language.federated_aggregate( encoded_values, zero_fn(), accumulate_fn, merge_fn, report_fn ) - decoded_values = intrinsics.federated_map( + decoded_values = federated_language.federated_map( decode_after_sum_fn, [aggregated_values.values, decode_after_sum_params] ) - updated_state = intrinsics.federated_map( + updated_state = federated_language.federated_map( update_state_fn, [state, aggregated_values.state_update_tensors] ) - empty_metrics = intrinsics.federated_value((), placements.SERVER) + empty_metrics = federated_language.federated_value( + (), federated_language.SERVER + ) return measured_process.MeasuredProcessOutput( state=updated_state, result=decoded_values, measurements=empty_metrics ) diff --git a/tensorflow_federated/python/aggregators/encoded_test.py b/tensorflow_federated/python/aggregators/encoded_test.py index 83d564f37d..aa79a52fca 100644 --- a/tensorflow_federated/python/aggregators/encoded_test.py +++ b/tensorflow_federated/python/aggregators/encoded_test.py @@ -16,14 +16,13 @@ import random from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import encoded from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_model_optimization.python.core.internal import tensor_encoding as te @@ -59,7 +58,9 @@ def _state_update_encoder_fn(value_spec): ) -_test_struct_type = computation_types.to_type(((np.float32, (20,)), np.float32)) +_test_struct_type = federated_language.to_type( + ((np.float32, (20,)), np.float32) +) class EncodedSumFactoryComputationTest( @@ -84,21 +85,23 @@ def test_type_properties(self, encoder_fn): server_state_type = process.initialize.type_signature.result # State structure should have one element per tensor aggregated, self.assertLen(server_state_type.member, 2) - self.assertEqual(placements.SERVER, server_state_type.placement) + self.assertEqual(federated_language.SERVER, server_state_type.placement) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - _test_struct_type, placements.CLIENTS + value=federated_language.FederatedType( + _test_struct_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - _test_struct_type, placements.SERVER + result=federated_language.FederatedType( + _test_struct_type, federated_language.SERVER + ), + measurements=federated_language.FederatedType( + (), federated_language.SERVER ), - measurements=computation_types.FederatedType((), placements.SERVER), ), ) self.assertTrue( @@ -129,7 +132,7 @@ class EncodedSumFactoryExecutionTest(tf.test.TestCase): def test_simple_sum(self): encoded_f = encoded.EncodedSumFactory(_identity_encoder_fn) - process = encoded_f.create(computation_types.to_type(np.float32)) + process = encoded_f.create(federated_language.to_type(np.float32)) state = process.initialize() @@ -143,7 +146,7 @@ def test_simple_sum(self): def test_structure_sum(self): encoded_f = encoded.EncodedSumFactory(_identity_encoder_fn) process = encoded_f.create( - computation_types.to_type(((np.float32, (2,)), np.float32)) + federated_language.to_type(((np.float32, (2,)), np.float32)) ) state = process.initialize() @@ -163,7 +166,7 @@ def test_quantize_above_threshold_zero(self): encoded_f = encoded.EncodedSumFactory.quantize_above_threshold( quantization_bits=1, threshold=0 ) - test_type = computation_types.to_type( + test_type = federated_language.to_type( [(np.float32, (3,)), (np.float32, (5,))] ) process = encoded_f.create(test_type) @@ -179,7 +182,7 @@ def test_quantize_above_threshold_positive(self): encoded_f = encoded.EncodedSumFactory.quantize_above_threshold( quantization_bits=1, threshold=4 ) - test_type = computation_types.to_type( + test_type = federated_language.to_type( [(np.float32, (3,)), (np.float32, (5,))] ) process = encoded_f.create(test_type) @@ -197,7 +200,7 @@ def test_quantize_above_threshold(self): quantization_bits=4, threshold=0 ) process = encoded_f.create( - computation_types.to_type((np.float32, (10000,))) + federated_language.to_type((np.float32, (10000,))) ) # Creates random values in range [0., 15.] plus the bondaries exactly. diff --git a/tensorflow_federated/python/aggregators/factory.py b/tensorflow_federated/python/aggregators/factory.py index 90efa74096..8b593acf86 100644 --- a/tensorflow_federated/python/aggregators/factory.py +++ b/tensorflow_federated/python/aggregators/factory.py @@ -16,10 +16,11 @@ import abc from typing import Union -from tensorflow_federated.python.core.impl.types import computation_types +import federated_language + from tensorflow_federated.python.core.templates import aggregation_process -ValueType = Union[computation_types.TensorType, computation_types.StructType] +ValueType = Union[federated_language.TensorType, federated_language.StructType] class UnweightedAggregationFactory(abc.ABC): diff --git a/tensorflow_federated/python/aggregators/factory_utils.py b/tensorflow_federated/python/aggregators/factory_utils.py index 8b7a7aa123..fafd2ddf3c 100644 --- a/tensorflow_federated/python/aggregators/factory_utils.py +++ b/tensorflow_federated/python/aggregators/factory_utils.py @@ -13,11 +13,9 @@ # limitations under the License. """Utilities for building aggregation factories.""" +import federated_language from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process @@ -52,10 +50,14 @@ def __init__(self, unweighted_factory: factory.UnweightedAggregationFactory): def create(self, value_type, weight_type): aggregator = self._factory.create(value_type) - @federated_computation.federated_computation( + @federated_language.federated_computation( aggregator.state_type, - computation_types.FederatedType(value_type, placements.CLIENTS), - computation_types.FederatedType(weight_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) def next_fn(state, value, weight): del weight # Unused. diff --git a/tensorflow_federated/python/aggregators/factory_utils_test.py b/tensorflow_federated/python/aggregators/factory_utils_test.py index a97259d3d5..d89aae1888 100644 --- a/tensorflow_federated/python/aggregators/factory_utils_test.py +++ b/tensorflow_federated/python/aggregators/factory_utils_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import federated_language import numpy as np import tensorflow as tf @@ -20,10 +21,9 @@ from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -_TEST_VALUE_TYPE = computation_types.TensorType(np.float32, (2,)) -_TEST_WEIGHT_TYPE = computation_types.TensorType(np.float32) +_TEST_VALUE_TYPE = federated_language.TensorType(np.float32, (2,)) +_TEST_WEIGHT_TYPE = federated_language.TensorType(np.float32) class UnweightedAsWeightedAggregationTest(tf.test.TestCase): diff --git a/tensorflow_federated/python/aggregators/mean.py b/tensorflow_federated/python/aggregators/mean.py index 0b8ce5a9e5..f491e09455 100644 --- a/tensorflow_federated/python/aggregators/mean.py +++ b/tensorflow_federated/python/aggregators/mean.py @@ -17,6 +17,7 @@ import typing from typing import Optional +import federated_language import numpy as np import tensorflow as tf @@ -24,11 +25,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -108,22 +104,26 @@ def create( value_sum_process = self._value_sum_factory.create(value_type) weight_sum_process = self._weight_sum_factory.create(weight_type) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): state = collections.OrderedDict( value_sum_process=value_sum_process.initialize(), weight_sum_process=weight_sum_process.initialize(), ) - return intrinsics.federated_zip(state) + return federated_language.federated_zip(state) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), - computation_types.FederatedType(weight_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) def next_fn(state, value, weight): # Client computation. - weighted_value = intrinsics.federated_map(_mul, (value, weight)) + weighted_value = federated_language.federated_map(_mul, (value, weight)) # Inner aggregations. value_output = value_sum_process.next( @@ -134,7 +134,7 @@ def next_fn(state, value, weight): ) # Server computation. - weighted_mean_value = intrinsics.federated_map( + weighted_mean_value = federated_language.federated_map( _div_no_nan if self._no_nan_division else _div, (value_output.result, weight_output.result), ) @@ -149,9 +149,9 @@ def next_fn(state, value, weight): mean_weight=weight_output.measurements, ) return measured_process.MeasuredProcessOutput( - intrinsics.federated_zip(state), + federated_language.federated_zip(state), weighted_mean_value, - intrinsics.federated_zip(measurements), + federated_language.federated_zip(measurements), ) return aggregation_process.AggregationProcess(init_fn, next_fn) @@ -211,33 +211,36 @@ def create( _check_value_type(value_type) value_sum_process = self._value_sum_factory.create(value_type) count_sum_process = self._count_sum_factory.create( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_zip( + return federated_language.federated_zip( (value_sum_process.initialize(), count_sum_process.initialize()) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): value_sum_state, count_sum_state = state value_sum_output = value_sum_process.next(value_sum_state, value) count_sum_output = count_sum_process.next( - count_sum_state, intrinsics.federated_value(1, placements.CLIENTS) + count_sum_state, + federated_language.federated_value(1, federated_language.CLIENTS), ) - mean_value = intrinsics.federated_map( + mean_value = federated_language.federated_map( _div, (value_sum_output.result, count_sum_output.result) ) - state = intrinsics.federated_zip( + state = federated_language.federated_zip( (value_sum_output.state, count_sum_output.state) ) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict( mean_value=value_sum_output.measurements, mean_count=count_sum_output.measurements, @@ -254,7 +257,7 @@ def next_fn(state, value): def _check_value_type(value_type): type_args = typing.get_args(factory.ValueType) py_typecheck.check_type(value_type, type_args) - if not type_analysis.is_structure_of_floats(value_type): + if not federated_language.framework.is_structure_of_floats(value_type): raise TypeError( 'All values in provided value_type must be of floating ' f'dtype. Provided value_type: {value_type}' diff --git a/tensorflow_federated/python/aggregators/mean_test.py b/tensorflow_federated/python/aggregators/mean_test.py index d9e2080cbd..a976c33476 100644 --- a/tensorflow_federated/python/aggregators/mean_test.py +++ b/tensorflow_federated/python/aggregators/mean_test.py @@ -16,6 +16,7 @@ import math from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -23,8 +24,6 @@ from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -46,8 +45,8 @@ class MeanFactoryComputationTest(tf.test.TestCase, parameterized.TestCase): ('struct_value_int64_weight', _test_struct_type, np.int64), ) def test_type_properties(self, value_type, weight_type): - value_type = computation_types.to_type(value_type) - weight_type = computation_types.to_type(weight_type) + value_type = federated_language.to_type(value_type) + weight_type = federated_language.to_type(weight_type) factory_ = mean.MeanFactory() self.assertIsInstance(factory_, factory.WeightedAggregationFactory) @@ -55,23 +54,23 @@ def test_type_properties(self, value_type, weight_type): self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( collections.OrderedDict(value_sum_process=(), weight_sum_process=()), - placements.SERVER, + federated_language.SERVER, ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict(mean_value=(), mean_weight=()), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -83,10 +82,12 @@ def test_type_properties(self, value_type, weight_type): expected_parameter = collections.OrderedDict( state=expected_state_type, value=param_value_type, - weight=computation_types.FederatedType(weight_type, placements.CLIENTS), + weight=federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=expected_parameter, result=measured_process.MeasuredProcessOutput( expected_state_type, result_value_type, expected_measurements_type @@ -101,7 +102,7 @@ def test_type_properties(self, value_type, weight_type): ('struct_value', _test_struct_type), ) def test_type_properties_unweighted(self, value_type): - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) factory_ = mean.UnweightedMeanFactory() self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) @@ -109,21 +110,22 @@ def test_type_properties_unweighted(self, value_type): self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType( - ((), ()), placements.SERVER + expected_state_type = federated_language.FederatedType( + ((), ()), federated_language.SERVER ) - expected_measurements_type = computation_types.FederatedType( - collections.OrderedDict(mean_value=(), mean_count=()), placements.SERVER + expected_measurements_type = federated_language.FederatedType( + collections.OrderedDict(mean_value=(), mean_count=()), + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -132,7 +134,7 @@ def test_type_properties_unweighted(self, value_type): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, value=param_value_type ), @@ -155,8 +157,8 @@ def test_type_properties_unweighted(self, value_type): ('struct_value_int64_weight', _test_struct_type, np.int64), ) def test_type_properties_with_inner_factory(self, value_type, weight_type): - value_type = computation_types.to_type(value_type) - weight_type = computation_types.to_type(weight_type) + value_type = federated_language.to_type(value_type) + weight_type = federated_language.to_type(weight_type) sum_factory = aggregator_test_utils.SumPlusOneFactory() factory_ = mean.MeanFactory( @@ -167,25 +169,25 @@ def test_type_properties_with_inner_factory(self, value_type, weight_type): self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( collections.OrderedDict( value_sum_process=np.int32, weight_sum_process=np.int32 ), - placements.SERVER, + federated_language.SERVER, ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict(mean_value=np.int32, mean_weight=np.int32), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -197,10 +199,12 @@ def test_type_properties_with_inner_factory(self, value_type, weight_type): expected_parameter = collections.OrderedDict( state=expected_state_type, value=param_value_type, - weight=computation_types.FederatedType(weight_type, placements.CLIENTS), + weight=federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=expected_parameter, result=measured_process.MeasuredProcessOutput( expected_state_type, result_value_type, expected_measurements_type @@ -215,7 +219,7 @@ def test_type_properties_with_inner_factory(self, value_type, weight_type): ('struct_value', _test_struct_type), ) def test_type_properties_with_inner_factory_unweighted(self, value_type): - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) sum_factory = aggregator_test_utils.SumPlusOneFactory() factory_ = mean.UnweightedMeanFactory( @@ -226,22 +230,22 @@ def test_type_properties_with_inner_factory_unweighted(self, value_type): self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType( - ((np.int32, np.int32)), placements.SERVER + expected_state_type = federated_language.FederatedType( + ((np.int32, np.int32)), federated_language.SERVER ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict(mean_value=np.int32, mean_count=np.int32), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -250,7 +254,7 @@ def test_type_properties_with_inner_factory_unweighted(self, value_type): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, value=param_value_type ), @@ -265,14 +269,16 @@ def test_type_properties_with_inner_factory_unweighted(self, value_type): @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_create_type_raises(self, wrong_type): factory_ = mean.MeanFactory() - correct_type = computation_types.TensorType(np.float32) + correct_type = federated_language.TensorType(np.float32) with self.assertRaises(TypeError): factory_.create(wrong_type, correct_type) with self.assertRaises(TypeError): @@ -287,8 +293,8 @@ class MeanFactoryExecutionTest(tf.test.TestCase): def test_scalar_value(self): factory_ = mean.MeanFactory() - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) expected_state = collections.OrderedDict( @@ -311,7 +317,7 @@ def test_scalar_value(self): def test_scalar_value_unweighted(self): factory_ = mean.UnweightedMeanFactory() - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory_.create(value_type) expected_state = ((), ()) @@ -331,8 +337,8 @@ def test_scalar_value_unweighted(self): def test_structure_value(self): factory_ = mean.MeanFactory() - value_type = computation_types.to_type(_test_struct_type) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.to_type(_test_struct_type) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) expected_state = collections.OrderedDict( value_sum_process=(), weight_sum_process=() @@ -353,7 +359,7 @@ def test_structure_value(self): def test_structure_value_unweighted(self): factory_ = mean.UnweightedMeanFactory() - value_type = computation_types.to_type(_test_struct_type) + value_type = federated_language.to_type(_test_struct_type) process = factory_.create(value_type) expected_state = ((), ()) expected_measurements = collections.OrderedDict( @@ -372,8 +378,8 @@ def test_structure_value_unweighted(self): def test_weight_arg(self): factory_ = mean.MeanFactory() - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) state = process.initialize() @@ -387,8 +393,8 @@ def test_weight_arg(self): def test_weight_arg_all_zeros_nan_division(self): factory_ = mean.MeanFactory(no_nan_division=False) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) state = process.initialize() @@ -402,8 +408,8 @@ def test_weight_arg_all_zeros_nan_division(self): def test_weight_arg_all_zeros_no_nan_division(self): factory_ = mean.MeanFactory(no_nan_division=True) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) state = process.initialize() @@ -416,8 +422,8 @@ def test_weight_arg_all_zeros_no_nan_division(self): def test_inner_value_sum_factory(self): sum_factory = aggregator_test_utils.SumPlusOneFactory() factory_ = mean.MeanFactory(value_sum_factory=sum_factory) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) state = process.initialize() @@ -446,7 +452,7 @@ def test_inner_value_sum_factory_unweighted(self): factory_ = mean.UnweightedMeanFactory( value_sum_factory=sum_factory, count_sum_factory=sum_factory ) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory_.create(value_type) state = process.initialize() @@ -466,8 +472,8 @@ def test_inner_value_sum_factory_unweighted(self): def test_inner_weight_sum_factory(self): sum_factory = aggregator_test_utils.SumPlusOneFactory() factory_ = mean.MeanFactory(weight_sum_factory=sum_factory) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) state = process.initialize() @@ -496,8 +502,8 @@ def test_inner_value_and_weight_sum_factory(self): factory_ = mean.MeanFactory( value_sum_factory=sum_factory, weight_sum_factory=sum_factory ) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory_.create(value_type, weight_type) state = process.initialize() diff --git a/tensorflow_federated/python/aggregators/measurements.py b/tensorflow_federated/python/aggregators/measurements.py index dc266bb4b2..21c18ed73e 100644 --- a/tensorflow_federated/python/aggregators/measurements.py +++ b/tensorflow_federated/python/aggregators/measurements.py @@ -18,13 +18,11 @@ import typing from typing import Any, Optional +import federated_language + from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -112,22 +110,26 @@ def create( inner_agg_process = inner_agg_factory.create(value_type, weight_type) init_fn = inner_agg_process.initialize - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), - computation_types.FederatedType(weight_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) def next_fn(state, value, weight): inner_agg_output = inner_agg_process.next(state, value, weight) measurements = inner_agg_output.measurements if client_measurement_fn: client_measurements = client_measurement_fn(value, weight) - measurements = intrinsics.federated_map( + measurements = federated_language.federated_map( dict_update, (measurements, client_measurements) ) if server_measurement_fn: server_measurements = server_measurement_fn(inner_agg_output.result) - measurements = intrinsics.federated_map( + measurements = federated_language.federated_map( dict_update, (measurements, server_measurements) ) return measured_process.MeasuredProcessOutput( @@ -153,21 +155,23 @@ def create( inner_agg_process = inner_agg_factory.create(value_type) init_fn = inner_agg_process.initialize - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): inner_agg_output = inner_agg_process.next(state, value) measurements = inner_agg_output.measurements if client_measurement_fn: client_measurements = client_measurement_fn(value) - measurements = intrinsics.federated_map( + measurements = federated_language.federated_map( dict_update, (measurements, client_measurements) ) if server_measurement_fn: server_measurements = server_measurement_fn(inner_agg_output.result) - measurements = intrinsics.federated_map( + measurements = federated_language.federated_map( dict_update, (measurements, server_measurements) ) return measured_process.MeasuredProcessOutput( diff --git a/tensorflow_federated/python/aggregators/measurements_test.py b/tensorflow_federated/python/aggregators/measurements_test.py index e1a47facdf..abc09fe57a 100644 --- a/tensorflow_federated/python/aggregators/measurements_test.py +++ b/tensorflow_federated/python/aggregators/measurements_test.py @@ -14,6 +14,7 @@ import collections +import federated_language import numpy as np import tensorflow as tf @@ -22,15 +23,12 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -_struct_type = computation_types.to_type([(np.float32, (3,)), np.float32]) -_struct_type_clients = computation_types.FederatedType( - _struct_type, placements.CLIENTS +_struct_type = federated_language.to_type([(np.float32, (3,)), np.float32]) +_struct_type_clients = federated_language.FederatedType( + _struct_type, federated_language.CLIENTS ) -_float_type = computation_types.to_type(np.float32) +_float_type = federated_language.to_type(np.float32) @tensorflow_computation.tf_computation @@ -44,8 +42,8 @@ def _make_struct(x): def _get_min_norm(value): - norms = intrinsics.federated_map(_get_norm, value) - min_norm = intrinsics.federated_min(norms) + norms = federated_language.federated_map(_get_norm, value) + min_norm = federated_language.federated_min(norms) return collections.OrderedDict(min_norm=min_norm) @@ -55,14 +53,16 @@ def _mul_struct(value, weight): def _get_min_weighted_norm(value, weight): - weighted_value = intrinsics.federated_map(_mul_struct, (value, weight)) - norms = intrinsics.federated_map(_get_norm, weighted_value) - min_weighted_norm = intrinsics.federated_min(norms) + weighted_value = federated_language.federated_map( + _mul_struct, (value, weight) + ) + norms = federated_language.federated_map(_get_norm, weighted_value) + min_weighted_norm = federated_language.federated_min(norms) return collections.OrderedDict(min_weighted_norm=min_weighted_norm) def _get_server_norm(value): - server_norm = intrinsics.federated_map(_get_norm, value) + server_norm = federated_language.federated_map(_get_norm, value) return collections.OrderedDict(server_norm=server_norm) diff --git a/tensorflow_federated/python/aggregators/modular_clipping.py b/tensorflow_federated/python/aggregators/modular_clipping.py index d6b317ce57..9468635f6c 100644 --- a/tensorflow_federated/python/aggregators/modular_clipping.py +++ b/tensorflow_federated/python/aggregators/modular_clipping.py @@ -18,18 +18,13 @@ from typing import Optional import warnings +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -153,14 +148,17 @@ def create( ) -> aggregation_process.AggregationProcess: # Checks value_type and compute client data dimension. if isinstance( - value_type, computation_types.StructType - ) and type_analysis.is_structure_of_tensors(value_type): + value_type, federated_language.StructType + ) and federated_language.framework.is_structure_of_tensors(value_type): num_elements_struct = type_conversions.structure_from_tensor_type_tree( - lambda x: array_shape.num_elements_in_shape(x.shape), value_type + lambda x: federated_language.num_elements_in_array_shape(x.shape), + value_type, ) client_dim = sum(tf.nest.flatten(num_elements_struct)) - elif isinstance(value_type, computation_types.TensorType): - client_dim = array_shape.num_elements_in_shape(value_type.shape) + elif isinstance(value_type, federated_language.TensorType): + client_dim = federated_language.num_elements_in_array_shape( + value_type.shape + ) else: raise TypeError( 'Expected `value_type` to be `TensorType` or ' @@ -168,7 +166,7 @@ def create( f'Found type: {repr(value_type)}' ) # Checks that all values are integers. - if not type_analysis.is_structure_of_integers(value_type): + if not federated_language.framework.is_structure_of_integers(value_type): raise TypeError( 'Component dtypes of `value_type` must all be integers. ' f'Found {repr(value_type)}.' @@ -204,32 +202,34 @@ def _create_next_fn(self, inner_agg_next, state_type, value_type): estimate_wrapped_gaussian_stddev ) - @federated_computation.federated_computation( + @federated_language.federated_computation( state_type, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): - clip_lower = intrinsics.federated_value( - self._clip_range_lower, placements.SERVER + clip_lower = federated_language.federated_value( + self._clip_range_lower, federated_language.SERVER ) - clip_upper = intrinsics.federated_value( - self._clip_range_upper, placements.SERVER + clip_upper = federated_language.federated_value( + self._clip_range_upper, federated_language.SERVER ) # Modular clip values before aggregation. - clipped_value = intrinsics.federated_map( + clipped_value = federated_language.federated_map( modular_clip_by_value_fn, ( value, - intrinsics.federated_broadcast(clip_lower), - intrinsics.federated_broadcast(clip_upper), + federated_language.federated_broadcast(clip_lower), + federated_language.federated_broadcast(clip_upper), ), ) inner_agg_output = inner_agg_next(state, clipped_value) # Clip the aggregate to the same range again (not considering summands). - clipped_agg_output_result = intrinsics.federated_map( + clipped_agg_output_result = federated_language.federated_map( modular_clip_by_value_fn, (inner_agg_output.result, clip_lower, clip_upper), ) @@ -239,7 +239,7 @@ def next_fn(state, value): ) if self._estimate_stddev: - estimate = intrinsics.federated_map( + estimate = federated_language.federated_map( estimator_fn, (clipped_agg_output_result, clip_lower, clip_upper) ) measurements['estimated_stddev'] = estimate @@ -247,7 +247,7 @@ def next_fn(state, value): return measured_process.MeasuredProcessOutput( state=inner_agg_output.state, result=clipped_agg_output_result, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return next_fn diff --git a/tensorflow_federated/python/aggregators/modular_clipping_test.py b/tensorflow_federated/python/aggregators/modular_clipping_test.py index 9f30b0908f..96164fa8a5 100644 --- a/tensorflow_federated/python/aggregators/modular_clipping_test.py +++ b/tensorflow_federated/python/aggregators/modular_clipping_test.py @@ -15,21 +15,24 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import modular_clipping from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process _test_struct_type = [(np.int32, (3,)), np.int32] -_int_at_server = computation_types.FederatedType(np.int32, placements.SERVER) -_int_at_clients = computation_types.FederatedType(np.int32, placements.CLIENTS) +_int_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER +) +_int_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS +) def _make_test_struct_value(x): @@ -62,7 +65,7 @@ class ModularClippingSumFactoryComputationTest( { 'value_type_1': (np.int32, [10]), 'value_type_2': _test_struct_type, - 'value_type_3': computation_types.StructType( + 'value_type_3': federated_language.StructType( [('a', np.int32), ('b', np.int32)] ), }, @@ -71,13 +74,15 @@ class ModularClippingSumFactoryComputationTest( ) def test_type_properties_simple(self, value_type, estimate_stddev): factory = _test_factory(estimate_stddev=estimate_stddev) - process = factory.create(computation_types.to_type(value_type)) + process = factory.create(federated_language.to_type(value_type)) self.assertIsInstance(process, aggregation_process.AggregationProcess) # Inner SumFactory has no state. - server_state_type = computation_types.FederatedType((), placements.SERVER) + server_state_type = federated_language.FederatedType( + (), federated_language.SERVER + ) - expected_init_type = computation_types.FunctionType( + expected_init_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -88,20 +93,20 @@ def test_type_properties_simple(self, value_type, estimate_stddev): if estimate_stddev: expected_measurements_type['estimated_stddev'] = np.float32 - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), - measurements=computation_types.FederatedType( - expected_measurements_type, placements.SERVER + measurements=federated_language.FederatedType( + expected_measurements_type, federated_language.SERVER ), ), ) @@ -145,18 +150,18 @@ def test_raise_on_invalid_estimate_stddev_type(self, value): ) def test_raise_on_estimate_stddev_for_single_element(self, value_type): factory = _test_factory(estimate_stddev=True) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(ValueError, 'more than 1 element'): factory.create(value_type) @parameterized.named_parameters( - ('sequence', computation_types.SequenceType(np.int32)), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('nested_sequence', [[[computation_types.SequenceType(np.int32)]]]), + ('sequence', federated_language.SequenceType(np.int32)), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('nested_sequence', [[[federated_language.SequenceType(np.int32)]]]), ) def test_tff_value_types_raise_on(self, value_type): factory = _test_factory() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'): factory.create(value_type) @@ -167,7 +172,7 @@ def test_tff_value_types_raise_on(self, value_type): ) def test_component_tensor_dtypes_raise_on(self, value_type): factory = _test_factory() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'must all be integers'): factory.create(value_type) @@ -206,7 +211,7 @@ def test_clip_individual_values( self, clip_range_lower, clip_range_upper, client_data, expected_sum ): factory = _test_factory(clip_range_lower, clip_range_upper) - value_type = computation_types.TensorType(np.int32) + value_type = federated_language.TensorType(np.int32) process = factory.create(value_type) state = process.initialize() output = process.next(state, client_data) @@ -222,7 +227,7 @@ def test_clip_sum( self, clip_range_lower, clip_range_upper, client_data, expected_sum ): factory = _test_factory(clip_range_lower, clip_range_upper) - value_type = computation_types.TensorType(np.int32) + value_type = federated_language.TensorType(np.int32) process = factory.create(value_type) state = process.initialize() output = process.next(state, client_data) @@ -238,7 +243,7 @@ def test_clip_sum_struct( self, clip_range_lower, clip_range_upper, client_data, expected_sum ): factory = _test_factory(clip_range_lower, clip_range_upper) - value_type = computation_types.to_type(_test_struct_type) + value_type = federated_language.to_type(_test_struct_type) process = factory.create(value_type) state = process.initialize() client_struct_data = [_make_test_struct_value(v) for v in client_data] diff --git a/tensorflow_federated/python/aggregators/primitives.py b/tensorflow_federated/python/aggregators/primitives.py index 27159c90a8..c34d4eb76a 100644 --- a/tensorflow_federated/python/aggregators/primitives.py +++ b/tensorflow_federated/python/aggregators/primitives.py @@ -15,6 +15,7 @@ from typing import Any, NamedTuple +import federated_language import numpy as np import tensorflow as tf @@ -22,16 +23,14 @@ from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _validate_value_on_clients(value): - py_typecheck.check_type(value, value_impl.Value) - py_typecheck.check_type(value.type_signature, computation_types.FederatedType) - if value.type_signature.placement is not placements.CLIENTS: + py_typecheck.check_type(value, federated_language.Value) + py_typecheck.check_type( + value.type_signature, federated_language.FederatedType + ) + if value.type_signature.placement is not federated_language.CLIENTS: raise TypeError( '`value` argument must be a tff.Value placed at CLIENTS. Got: {!s}' .format(value.type_signature) @@ -91,17 +90,17 @@ def _get_accumulator_type(member_type): """ def add_unknown_first_dim(tensor_type): - return computation_types.TensorType( + return federated_language.TensorType( tensor_type.dtype, (None,) + tensor_type.shape ) accumulator_type = type_conversions.structure_from_tensor_type_tree( add_unknown_first_dim, member_type ) - return computation_types.to_type( + return federated_language.to_type( _Samples( accumulators=accumulator_type, - rands=computation_types.TensorType(np.float32, shape=[None]), + rands=federated_language.TensorType(np.float32, shape=[None]), ) ) @@ -170,7 +169,9 @@ def zero_axis_concat(a, b): def report(value): return value.accumulators - return intrinsics.federated_aggregate(value, zeros, accumulate, merge, report) + return federated_language.federated_aggregate( + value, zeros, accumulate, merge, report + ) # Lower precision types are not supported to avoid potential hard to discover @@ -314,7 +315,7 @@ def _normalize_secure_quantized_sum_args( _validate_value_on_clients(client_value) client_value_member = client_value.type_signature.member if isinstance( - client_value.type_signature.member, computation_types.StructType + client_value.type_signature.member, federated_language.StructType ): dtypes = [v.dtype for v in structure.flatten(client_value_member)] for dtype in dtypes: @@ -324,33 +325,37 @@ def _normalize_secure_quantized_sum_args( _check_secure_quantized_sum_dtype(dtypes) # Validation of bounds. - if isinstance(lower_bound, value_impl.Value) != isinstance( - upper_bound, value_impl.Value + if isinstance(lower_bound, federated_language.Value) != isinstance( + upper_bound, federated_language.Value ): raise BoundsDifferentTypesError(lower_bound, upper_bound) - elif not isinstance(lower_bound, value_impl.Value): + elif not isinstance(lower_bound, federated_language.Value): # Normalization of bounds to federated values. - lower_bound = intrinsics.federated_value(lower_bound, placements.SERVER) - upper_bound = intrinsics.federated_value(upper_bound, placements.SERVER) + lower_bound = federated_language.federated_value( + lower_bound, federated_language.SERVER + ) + upper_bound = federated_language.federated_value( + upper_bound, federated_language.SERVER + ) if lower_bound.type_signature != upper_bound.type_signature: raise BoundsDifferentSignaturesError(lower_bound, upper_bound) # The remaining type checks only use lower_bound as the upper_bound has # itendical type_signature. - if lower_bound.type_signature.placement != placements.SERVER: # pytype: disable=attribute-error + if lower_bound.type_signature.placement != federated_language.SERVER: # pytype: disable=attribute-error raise BoundsNotPlacedAtServerError(lower_bound.type_signature.placement) # pytype: disable=attribute-error # Validation of client_value and bounds compatibility. bound_member = lower_bound.type_signature.member # pytype: disable=attribute-error - if isinstance(bound_member, computation_types.StructType): - if not isinstance(client_value_member, computation_types.StructType) or ( + if isinstance(bound_member, federated_language.StructType): + if not isinstance(client_value_member, federated_language.StructType) or ( structure.map_structure(lambda v: v.dtype, bound_member) != structure.map_structure(lambda v: v.dtype, client_value_member) ): raise StructuredBoundsTypeMismatchError(client_value_member, bound_member) else: # If bounds are scalar, must be compatible with all tensors in client_value. - if isinstance(client_value_member, computation_types.StructType): + if isinstance(client_value_member, federated_language.StructType): if len(set(dtypes)) > 1 or (bound_member.dtype != dtypes[0]): raise ScalarBoundStructValueDTypeError( client_value_member, bound_member @@ -630,30 +635,32 @@ def server_shift(value, lower_bnd, upper_bnd, summands): temp_box[0], ) - client_one = intrinsics.federated_value(1, placements.CLIENTS) + client_one = federated_language.federated_value(1, federated_language.CLIENTS) # Orchestration. - client_lower_bound = intrinsics.federated_broadcast(lower_bound) - client_upper_bound = intrinsics.federated_broadcast(upper_bound) + client_lower_bound = federated_language.federated_broadcast(lower_bound) + client_upper_bound = federated_language.federated_broadcast(upper_bound) - value = intrinsics.federated_map( + value = federated_language.federated_map( client_shift, (client_value, client_lower_bound, client_upper_bound) ) - num_summands = intrinsics.federated_secure_sum_bitwidth( + num_summands = federated_language.federated_secure_sum_bitwidth( client_one, bitwidth=1 ) secagg_value_type = value.type_signature.member # pytype: disable=attribute-error assert isinstance( - secagg_value_type, computation_types.TensorType - ) or isinstance(secagg_value_type, computation_types.StructType) - if isinstance(secagg_value_type, computation_types.TensorType): + secagg_value_type, federated_language.TensorType + ) or isinstance(secagg_value_type, federated_language.StructType) + if isinstance(secagg_value_type, federated_language.TensorType): bitwidths = 32 else: bitwidths = structure.map_structure(lambda t: 32, secagg_value_type) - value = intrinsics.federated_secure_sum_bitwidth(value, bitwidth=bitwidths) - value = intrinsics.federated_map( + value = federated_language.federated_secure_sum_bitwidth( + value, bitwidth=bitwidths + ) + value = federated_language.federated_map( server_shift, (value, lower_bound, upper_bound, num_summands) ) return value diff --git a/tensorflow_federated/python/aggregators/primitives_test.py b/tensorflow_federated/python/aggregators/primitives_test.py index 07ba2de074..6153152d95 100644 --- a/tensorflow_federated/python/aggregators/primitives_test.py +++ b/tensorflow_federated/python/aggregators/primitives_test.py @@ -15,16 +15,12 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import primitives from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.test import static_assert _MIN_MAX_TEST_DTYPES = [ ('int16', np.int16), @@ -40,8 +36,8 @@ class FederatedSampleTest(tf.test.TestCase): def test_federated_sample_single_value(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.CLIENTS) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -52,8 +48,8 @@ def call_federated_sample(value): def test_federated_sample_on_nested_scalars(self): tuple_type = collections.OrderedDict(x=np.float32, y=np.float32) - @federated_computation.federated_computation( - computation_types.FederatedType(tuple_type, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(tuple_type, federated_language.CLIENTS) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -78,8 +74,8 @@ def test_federated_sample_wrong_placement(self): TypeError, r'.*argument must be a tff.Value placed at CLIENTS.*' ): - @federated_computation.federated_computation( - computation_types.FederatedType(np.bool_, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.bool_, federated_language.SERVER) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -88,8 +84,8 @@ def call_federated_sample(value): def test_federated_sample_max_size_is_100(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.CLIENTS) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -100,8 +96,8 @@ def call_federated_sample(value): def test_federated_sample_preserves_nan_percentage(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.CLIENTS) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -111,8 +107,8 @@ def call_federated_sample(value): def test_federated_sample_preserves_inf_percentage(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.CLIENTS) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -121,12 +117,12 @@ def call_federated_sample(value): self.assertAlmostEqual(np.count_nonzero(np.isinf(value)), 50, delta=20) def test_federated_sample_named_tuple_type_of_ordered_dict(self): - dict_type = computation_types.to_type( + dict_type = federated_language.to_type( collections.OrderedDict([('x', np.float32), ('y', np.float32)]) ) - @federated_computation.federated_computation( - computation_types.FederatedType(dict_type, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(dict_type, federated_language.CLIENTS) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -142,7 +138,7 @@ def call_federated_sample(value): def test_federated_sample_nested_named_tuples(self): tuple_test_type = collections.OrderedDict(x=np.float32, y=np.float32) - dict_test_type = computation_types.to_type( + dict_test_type = federated_language.to_type( collections.OrderedDict(a=np.float32, b=np.float32) ) nested_tuple_type = collections.OrderedDict( @@ -150,8 +146,10 @@ def test_federated_sample_nested_named_tuples(self): ) nested_test_type = collections.namedtuple('Nested', ['tuple_1', 'tuple_2']) - @federated_computation.federated_computation( - computation_types.FederatedType(nested_tuple_type, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType( + nested_tuple_type, federated_language.CLIENTS + ) ) def call_federated_sample(value): return primitives.federated_sample(value) @@ -180,8 +178,10 @@ def test_contains_static_aggregation(self, dtype): """Tests that built computation contains at least one secure sum call.""" # Bounds provided as Python constants. - @federated_computation.federated_computation( - computation_types.FederatedType((dtype, (2,)), placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType( + (dtype, (2,)), federated_language.CLIENTS + ) ) def comp_py_bounds(value): return primitives.secure_quantized_sum( @@ -190,19 +190,25 @@ def comp_py_bounds(value): np.array(1.0, dtype), ) - static_assert.assert_not_contains_unsecure_aggregation(comp_py_bounds) + federated_language.framework.assert_not_contains_unsecure_aggregation( + comp_py_bounds + ) # Bounds provided as tff values. - @federated_computation.federated_computation( - computation_types.FederatedType((dtype, (2,)), placements.CLIENTS), - computation_types.FederatedType(dtype, placements.SERVER), - computation_types.FederatedType(dtype, placements.SERVER), + @federated_language.federated_computation( + federated_language.FederatedType( + (dtype, (2,)), federated_language.CLIENTS + ), + federated_language.FederatedType(dtype, federated_language.SERVER), + federated_language.FederatedType(dtype, federated_language.SERVER), ) def comp_tff_bounds(value, upper_bound, lower_bound): return primitives.secure_quantized_sum(value, upper_bound, lower_bound) try: - static_assert.assert_not_contains_unsecure_aggregation(comp_tff_bounds) + federated_language.framework.assert_not_contains_unsecure_aggregation( + comp_tff_bounds + ) except AssertionError: self.fail('Computation contains non-secure aggregation.') @@ -560,7 +566,7 @@ def test_scalar_int_type_py_range(self, int_type): @parameterized.named_parameters(('int32', np.int32), ('int64', np.int64)) def test_tensor_int_type_py_range(self, int_type): """Tests value of integer tensor type and scalar np range.""" - t_type = computation_types.TensorType(int_type, (2,)) + t_type = federated_language.TensorType(int_type, (2,)) call_secure_sum = _build_test_sum_fn_py_bounds( t_type, np.array(0, int_type), np.array(255, int_type) ) @@ -570,7 +576,7 @@ def test_tensor_int_type_py_range(self, int_type): @parameterized.named_parameters(('int32', np.int32), ('int64', np.int64)) def test_composite_int_type_py_range(self, int_type): """Tests value of integer composite type and scalar np range.""" - t_type = computation_types.to_type(((int_type, (2,)), (int_type, (3,)))) + t_type = federated_language.to_type(((int_type, (2,)), (int_type, (3,)))) call_secure_sum = _build_test_sum_fn_py_bounds( t_type, np.array(0, int_type), np.array(255, int_type) ) @@ -585,7 +591,7 @@ def test_composite_int_type_py_range(self, int_type): @parameterized.named_parameters(('int32', np.int32), ('int64', np.int64)) def test_composite_int_type_composite_py_range(self, int_type): """Tests value of integer composite type and composite np range.""" - t_type = computation_types.to_type(((int_type, (2,)), (int_type, (3,)))) + t_type = federated_language.to_type(((int_type, (2,)), (int_type, (3,)))) call_secure_sum = _build_test_sum_fn_py_bounds( t_type, (np.array(0, int_type), np.array(63, int_type)), @@ -612,7 +618,7 @@ def test_scalar_int_type_tff_range(self, int_type): def test_tensor_int_type_tff_range(self, int_type): """Tests value of integer tensor type and scalar tff range.""" call_secure_sum = _build_test_sum_fn_tff_bounds( - computation_types.TensorType(int_type, (2,)), int_type, int_type + federated_language.TensorType(int_type, (2,)), int_type, int_type ) self.assertAllEqual( [256, 7], call_secure_sum([[0, 0], [1, 2], [255, 5]], 0, 255) @@ -621,7 +627,7 @@ def test_tensor_int_type_tff_range(self, int_type): @parameterized.named_parameters(('int32', np.int32), ('int64', np.int64)) def test_composite_int_type_tff_range(self, int_type): """Tests value of integer composite type and scalar tff range.""" - t_type = computation_types.to_type(((int_type, (2,)), (int_type, (3,)))) + t_type = federated_language.to_type(((int_type, (2,)), (int_type, (3,)))) call_secure_sum = _build_test_sum_fn_tff_bounds(t_type, int_type, int_type) data = [((0, 0), (0, 0, 0)), ((1, 2), (3, 4, 5)), @@ -634,7 +640,7 @@ def test_composite_int_type_tff_range(self, int_type): @parameterized.named_parameters(('int32', np.int32), ('int64', np.int64)) def test_composite_int_type_composite_tff_range(self, int_type): """Tests value of integer composite type and composite tff range.""" - t_type = computation_types.to_type(((int_type, (2,)), (int_type, (3,)))) + t_type = federated_language.to_type(((int_type, (2,)), (int_type, (3,)))) call_secure_sum = _build_test_sum_fn_tff_bounds( t_type, (int_type, int_type), (int_type, int_type) ) @@ -725,7 +731,7 @@ def test_scalar_float_type_py_range(self, float_type): ) def test_tensor_float_type_py_range(self, float_type): """Tests value of float tensor type and scalar np range.""" - t_type = computation_types.TensorType(float_type, (2,)) + t_type = federated_language.TensorType(float_type, (2,)) call_secure_sum = _build_test_sum_fn_py_bounds( t_type, np.array(0.0, float_type), np.array(1.0, float_type) ) @@ -737,7 +743,9 @@ def test_tensor_float_type_py_range(self, float_type): ) def test_composite_float_type_py_range(self, float_type): """Tests value of float composite type and scalar np range.""" - t_type = computation_types.to_type(((float_type, (2,)), (float_type, (3,)))) + t_type = federated_language.to_type( + ((float_type, (2,)), (float_type, (3,))) + ) call_secure_sum = _build_test_sum_fn_py_bounds( t_type, np.array(0.0, float_type), np.array(1.0, float_type) ) @@ -754,7 +762,9 @@ def test_composite_float_type_py_range(self, float_type): ) def test_composite_float_type_composite_py_range(self, float_type): """Tests value of float composite type and composite np range.""" - t_type = computation_types.to_type(((float_type, (2,)), (float_type, (3,)))) + t_type = federated_language.to_type( + ((float_type, (2,)), (float_type, (3,))) + ) call_secure_sum = _build_test_sum_fn_py_bounds( t_type, (np.array(0.0, float_type), np.array(0.2, float_type)), @@ -786,7 +796,7 @@ def test_scalar_float_type_tff_range(self, float_type): ) def test_tensor_float_type_tff_range(self, float_type): """Tests value of float tensor type and scalar tff range.""" - t_type = computation_types.TensorType(float_type, (2,)) + t_type = federated_language.TensorType(float_type, (2,)) call_secure_sum = _build_test_sum_fn_tff_bounds( t_type, float_type, float_type ) @@ -798,7 +808,9 @@ def test_tensor_float_type_tff_range(self, float_type): ) def test_composite_float_type_tff_range(self, float_type): """Tests value of float composite type and scalar tff range.""" - t_type = computation_types.to_type(((float_type, (2,)), (float_type, (3,)))) + t_type = federated_language.to_type( + ((float_type, (2,)), (float_type, (3,))) + ) call_secure_sum = _build_test_sum_fn_tff_bounds( t_type, float_type, float_type ) @@ -815,7 +827,9 @@ def test_composite_float_type_tff_range(self, float_type): ) def test_composite_float_type_composite_tff_range(self, float_type): """Tests value of float composite type and composite tff range.""" - t_type = computation_types.to_type(((float_type, (2,)), (float_type, (3,)))) + t_type = federated_language.to_type( + ((float_type, (2,)), (float_type, (3,))) + ) call_secure_sum = _build_test_sum_fn_tff_bounds( t_type, (float_type, float_type), (float_type, float_type) ) @@ -863,7 +877,7 @@ def test_float_type_non_zero_lower_bound(self): def test_mixed_type_structure(self): """Tests a structure consisting of different dtypes can be aggregted.""" - t_type = computation_types.to_type(((np.int32, (2,)), (np.float32, (3,)))) + t_type = federated_language.to_type(((np.int32, (2,)), (np.float32, (3,)))) call_secure_sum = _build_test_sum_fn_py_bounds(t_type, (0, 0.0), (255, 1.0)) data = [((0, 0), (0.0, 0.0, 0.0)), ((1, 2), (0.3, 0.4, 0.5)), @@ -1185,11 +1199,13 @@ def test_range_type_mismatch_raises(self): def test_bounds_different_types_raises(self): with self.assertRaises(primitives.BoundsDifferentTypesError): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def _(value): - lower_bound = intrinsics.federated_value(0, placements.SERVER) + lower_bound = federated_language.federated_value( + 0, federated_language.SERVER + ) upper_bound = 1 summed_value = primitives.secure_quantized_sum( value, lower_bound, upper_bound @@ -1199,12 +1215,16 @@ def _(value): def test_clients_placed_bounds_raises(self): with self.assertRaises(primitives.BoundsNotPlacedAtServerError): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def _(value): - lower_bound = intrinsics.federated_value(0, placements.CLIENTS) - upper_bound = intrinsics.federated_value(1, placements.CLIENTS) + lower_bound = federated_language.federated_value( + 0, federated_language.CLIENTS + ) + upper_bound = federated_language.federated_value( + 1, federated_language.CLIENTS + ) summed_value = primitives.secure_quantized_sum( value, lower_bound, upper_bound ) @@ -1238,8 +1258,8 @@ def _build_test_sum_fn_py_bounds(value_type, lower_bound, upper_bound): value_type@SERVER)`. """ - @federated_computation.federated_computation( - computation_types.FederatedType(value_type, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(value_type, federated_language.CLIENTS) ) def call_secure_sum(value): summed_value = primitives.secure_quantized_sum( @@ -1269,10 +1289,14 @@ def _build_test_sum_fn_tff_bounds( lower_bound_type@SERVER, upper_bound_type@SERVER) -> value_type@SERVER)`. """ - @federated_computation.federated_computation( - computation_types.FederatedType(value_type, placements.CLIENTS), - computation_types.FederatedType(lower_bound_type, placements.SERVER), - computation_types.FederatedType(upper_bound_type, placements.SERVER), + @federated_language.federated_computation( + federated_language.FederatedType(value_type, federated_language.CLIENTS), + federated_language.FederatedType( + lower_bound_type, federated_language.SERVER + ), + federated_language.FederatedType( + upper_bound_type, federated_language.SERVER + ), ) def call_secure_sum(value, lower_bound, upper_bound): summed_value = primitives.secure_quantized_sum( diff --git a/tensorflow_federated/python/aggregators/quantile_estimation.py b/tensorflow_federated/python/aggregators/quantile_estimation.py index 9be134f481..405692e8f4 100644 --- a/tensorflow_federated/python/aggregators/quantile_estimation.py +++ b/tensorflow_federated/python/aggregators/quantile_estimation.py @@ -15,6 +15,7 @@ from typing import Optional +import federated_language import numpy as np import tensorflow_privacy as tfp @@ -23,10 +24,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import estimation_process @@ -156,24 +153,30 @@ def __init__( ) # 2. Define federated_computations. - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_zip(( - intrinsics.federated_eval(initial_state_fn, placements.SERVER), + return federated_language.federated_zip(( + federated_language.federated_eval( + initial_state_fn, federated_language.SERVER + ), quantile_agg_process.initialize(), )) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def next_fn(state, value): quantile_query_state, agg_state = state - params = intrinsics.federated_broadcast( - intrinsics.federated_map(derive_sample_params, quantile_query_state) + params = federated_language.federated_broadcast( + federated_language.federated_map( + derive_sample_params, quantile_query_state + ) ) - quantile_record = intrinsics.federated_map( + quantile_record = federated_language.federated_map( get_quantile_record, (params, value) ) @@ -181,15 +184,15 @@ def next_fn(state, value): agg_state, quantile_record ) - _, new_quantile_query_state, _ = intrinsics.federated_map( + _, new_quantile_query_state, _ = federated_language.federated_map( get_noised_result, (quantile_agg_output.result, quantile_query_state) ) - return intrinsics.federated_zip( + return federated_language.federated_zip( (new_quantile_query_state, quantile_agg_output.state) ) - report_fn = federated_computation.federated_computation( + report_fn = federated_language.federated_computation( lambda state: state[0].current_estimate, init_fn.type_signature.result ) @@ -200,9 +203,9 @@ def _affine_transform(multiplier, increment): transform_tf_comp = tensorflow_computation.tf_computation( lambda value: multiplier * value + increment, np.float32 ) - return federated_computation.federated_computation( - lambda value: intrinsics.federated_map(transform_tf_comp, value), - computation_types.FederatedType(np.float32, placements.SERVER), + return federated_language.federated_computation( + lambda value: federated_language.federated_map(transform_tf_comp, value), + federated_language.FederatedType(np.float32, federated_language.SERVER), ) diff --git a/tensorflow_federated/python/aggregators/quantile_estimation_test.py b/tensorflow_federated/python/aggregators/quantile_estimation_test.py index 43c9981d04..1d65c47196 100644 --- a/tensorflow_federated/python/aggregators/quantile_estimation_test.py +++ b/tensorflow_federated/python/aggregators/quantile_estimation_test.py @@ -15,16 +15,14 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf import tensorflow_privacy as tfp from tensorflow_federated.python.aggregators import quantile_estimation from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import estimation_process -from tensorflow_federated.python.core.test import static_assert QEProcess = quantile_estimation.PrivateQuantileEstimationProcess @@ -42,10 +40,10 @@ def test_process_type_signature(self, private): expected_num_records=100, geometric_update=True, ) - below_estimate_state = computation_types.StructType([ + below_estimate_state = federated_language.StructType([ ( 'numerator_state', - computation_types.StructType( + federated_language.StructType( [('l2_norm_clip', np.float32), ('stddev', np.float32)] ), ), @@ -59,7 +57,7 @@ def test_process_type_signature(self, private): geometric_update=True, ) below_estimate_state = () - query_state_type = computation_types.StructType([ + query_state_type = federated_language.StructType([ ('current_estimate', np.float32), ('target_quantile', np.float32), ('learning_rate', np.float32), @@ -68,14 +66,14 @@ def test_process_type_signature(self, private): process = QEProcess(quantile_estimator_query) sum_process_state_type = () - state_type = computation_types.StructType( + state_type = federated_language.StructType( [query_state_type, sum_process_state_type] ) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( state_type, - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type_signature = computation_types.FunctionType( + expected_initialize_type_signature = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -84,10 +82,10 @@ def test_process_type_signature(self, private): ) ) - estimate_type = computation_types.FederatedType( - np.float32, placements.SERVER + estimate_type = federated_language.FederatedType( + np.float32, federated_language.SERVER ) - expected_report_type_signature = computation_types.FunctionType( + expected_report_type_signature = federated_language.FunctionType( parameter=server_state_type, result=estimate_type ) self.assertTrue( @@ -96,12 +94,12 @@ def test_process_type_signature(self, private): ) ) - client_value_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_value_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) self.assertTrue( process.next.type_signature.is_equivalent_to( - computation_types.FunctionType( + federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, value=client_value_type ), @@ -218,7 +216,9 @@ def test_secure_estimation_true_only_contains_secure_aggregation(self): learning_rate=1.0, secure_estimation=True, ) - static_assert.assert_not_contains_unsecure_aggregation(secure_process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + secure_process.next + ) if __name__ == '__main__': diff --git a/tensorflow_federated/python/aggregators/robust.py b/tensorflow_federated/python/aggregators/robust.py index 161d0c30c7..810944af91 100644 --- a/tensorflow_federated/python/aggregators/robust.py +++ b/tensorflow_federated/python/aggregators/robust.py @@ -19,6 +19,7 @@ import typing from typing import Optional, TypeVar, Union +import federated_language import numpy as np import tensorflow as tf @@ -26,12 +27,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process @@ -43,16 +38,20 @@ def _constant_process(value): """Creates an `EstimationProcess` that reports a constant value.""" - init_fn = federated_computation.federated_computation( - lambda: intrinsics.federated_value((), placements.SERVER) + init_fn = federated_language.federated_computation( + lambda: federated_language.federated_value((), federated_language.SERVER) ) - next_fn = federated_computation.federated_computation( + next_fn = federated_language.federated_computation( lambda state, value: state, init_fn.type_signature.result, - computation_types.FederatedType(NORM_TF_TYPE, placements.CLIENTS), + federated_language.FederatedType( + NORM_TF_TYPE, federated_language.CLIENTS + ), ) - report_fn = federated_computation.federated_computation( - lambda state: intrinsics.federated_value(value, placements.SERVER), + report_fn = federated_language.federated_computation( + lambda state: federated_language.federated_value( + value, federated_language.SERVER + ), init_fn.type_signature.result, ) return estimation_process.EstimationProcess(init_fn, next_fn, report_fn) @@ -76,7 +75,7 @@ def _check_norm_process( next_parameter_type = norm_process.next.type_signature.parameter if ( - not isinstance(next_parameter_type, computation_types.StructType) + not isinstance(next_parameter_type, federated_language.StructType) or len(next_parameter_type) != 2 ): raise TypeError( @@ -84,8 +83,8 @@ def _check_norm_process( f'{next_parameter_type}' ) - norm_type_at_clients = computation_types.FederatedType( - NORM_TF_TYPE, placements.CLIENTS + norm_type_at_clients = federated_language.FederatedType( + NORM_TF_TYPE, federated_language.CLIENTS ) if not next_parameter_type[1].is_assignable_from(norm_type_at_clients): # pytype: disable=unsupported-operands raise TypeError( @@ -102,8 +101,8 @@ def _check_norm_process( ) result_type = norm_process.report.type_signature.result - norm_type_at_server = computation_types.FederatedType( - NORM_TF_TYPE, placements.SERVER + norm_type_at_server = federated_language.FederatedType( + NORM_TF_TYPE, federated_language.SERVER ) if not norm_type_at_server.is_assignable_from(result_type): raise TypeError( @@ -286,7 +285,9 @@ def _make_wrapper( clipping_norm: Union[float, estimation_process.EstimationProcess], inner_agg_factory: _T, clipped_count_sum_factory: factory.UnweightedAggregationFactory, - make_clip_fn: Callable[[factory.ValueType], computation_base.Computation], + make_clip_fn: Callable[ + [factory.ValueType], federated_language.framework.Computation + ], attribute_prefix: str, ) -> _T: """Constructs an aggregation factory that applies clip_fn before aggregation. @@ -326,7 +327,7 @@ def _make_wrapper( _check_norm_process(clipping_norm_process, 'clipping_norm_process') clipped_count_agg_process = clipped_count_sum_factory.create( - computation_types.to_type(COUNT_TF_TYPE) # pytype: disable=wrong-arg-types + federated_language.to_type(COUNT_TF_TYPE) # pytype: disable=wrong-arg-types ) prefix = lambda s: attribute_prefix + s @@ -337,16 +338,18 @@ def init_fn_impl(inner_agg_process): ('inner_agg', inner_agg_process.initialize()), (prefix('ed_count_agg'), clipped_count_agg_process.initialize()), ]) - return intrinsics.federated_zip(state) + return federated_language.federated_zip(state) def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None): clipping_norm_state, agg_state, clipped_count_state = state clipping_norm = clipping_norm_process.report(clipping_norm_state) - clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm) + clients_clipping_norm = federated_language.federated_broadcast( + clipping_norm + ) - clipped_value, global_norm, was_clipped = intrinsics.federated_map( + clipped_value, global_norm, was_clipped = federated_language.federated_map( clip_fn, (value, clients_clipping_norm) ) @@ -375,9 +378,9 @@ def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None): ]) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip(new_state), + state=federated_language.federated_zip(new_state), result=agg_output.result, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) if isinstance(inner_agg_factory, factory.WeightedAggregationFactory): @@ -395,14 +398,18 @@ def create( inner_agg_process = inner_agg_factory.create(value_type, weight_type) clip_fn = make_clip_fn(value_type) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): return init_fn_impl(inner_agg_process) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), - computation_types.FederatedType(weight_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + weight_type, federated_language.CLIENTS + ), ) def next_fn(state, value, weight): return next_fn_impl(state, value, clip_fn, inner_agg_process, weight) @@ -423,13 +430,15 @@ def create( inner_agg_process = inner_agg_factory.create(value_type) clip_fn = make_clip_fn(value_type) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): return init_fn_impl(inner_agg_process) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): return next_fn_impl(state, value, clip_fn, inner_agg_process) @@ -444,7 +453,7 @@ def next_fn(state, value): def _check_value_type(value_type): type_args = typing.get_args(factory.ValueType) py_typecheck.check_type(value_type, type_args) - if not type_analysis.is_structure_of_floats(value_type): + if not federated_language.framework.is_structure_of_floats(value_type): raise TypeError( 'All values in provided value_type must be of floating ' f'dtype. Provided value_type: {value_type}' diff --git a/tensorflow_federated/python/aggregators/robust_test.py b/tensorflow_federated/python/aggregators/robust_test.py index 3de4929bbd..8e33279a24 100644 --- a/tensorflow_federated/python/aggregators/robust_test.py +++ b/tensorflow_federated/python/aggregators/robust_test.py @@ -16,6 +16,7 @@ import itertools from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -25,11 +26,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process @@ -57,31 +53,29 @@ def _zeroed_sum(clip=2.0, norm_order=2.0): return robust.zeroing_factory(clip, sum_factory.SumFactory(), norm_order) -_float_at_server = computation_types.FederatedType( - computation_types.TensorType(np.float32), placements.SERVER +_float_at_server = federated_language.FederatedType( + federated_language.TensorType(np.float32), federated_language.SERVER ) -_float_at_clients = computation_types.FederatedType( - computation_types.TensorType(np.float32), placements.CLIENTS +_float_at_clients = federated_language.FederatedType( + federated_language.TensorType(np.float32), federated_language.CLIENTS ) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _test_init_fn(): - return intrinsics.federated_value(1.0, placements.SERVER) + return federated_language.federated_value(1.0, federated_language.SERVER) -@federated_computation.federated_computation( - _float_at_server, _float_at_clients -) +@federated_language.federated_computation(_float_at_server, _float_at_clients) def _test_next_fn(state, value): del value - return intrinsics.federated_map( + return federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: x + 1.0, np.float32), state, ) -@federated_computation.federated_computation(_float_at_server) +@federated_language.federated_computation(_float_at_server) def _test_report_fn(state): return state @@ -100,17 +94,17 @@ class ClippingFactoryComputationTest(tf.test.TestCase, parameterized.TestCase): ) def test_clip_type_properties_simple(self, value_type): factory = _clipped_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( clipping_norm=(), inner_agg=(), clipped_count_agg=() ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -119,25 +113,25 @@ def test_clip_type_properties_simple(self, value_type): ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( clipping=(), clipping_norm=robust.NORM_TF_TYPE, clipped_count=robust.COUNT_TF_TYPE, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -154,21 +148,21 @@ def test_clip_type_properties_simple(self, value_type): ) def test_clip_type_properties_weighted(self, value_type, weight_type): factory = _clipped_mean() - value_type = computation_types.to_type(value_type) - weight_type = computation_types.to_type(weight_type) + value_type = federated_language.to_type(value_type) + weight_type = federated_language.to_type(weight_type) process = factory.create(value_type, weight_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) mean_state_type = collections.OrderedDict( value_sum_process=(), weight_sum_process=() ) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( clipping_norm=(), inner_agg=mean_state_type, clipped_count_agg=() ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -177,28 +171,28 @@ def test_clip_type_properties_weighted(self, value_type, weight_type): ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( clipping=collections.OrderedDict(mean_value=(), mean_weight=()), clipping_norm=robust.NORM_TF_TYPE, clipped_count=robust.COUNT_TF_TYPE, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), - weight=computation_types.FederatedType( - weight_type, placements.CLIENTS + weight=federated_language.FederatedType( + weight_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -213,17 +207,17 @@ def test_clip_type_properties_weighted(self, value_type, weight_type): ) def test_zero_type_properties_simple(self, value_type): factory = _zeroed_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( zeroing_norm=(), inner_agg=(), zeroed_count_agg=() ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -232,25 +226,25 @@ def test_zero_type_properties_simple(self, value_type): ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( zeroing=(), zeroing_norm=robust.NORM_TF_TYPE, zeroed_count=robust.COUNT_TF_TYPE, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -267,21 +261,21 @@ def test_zero_type_properties_simple(self, value_type): ) def test_zero_type_properties_weighted(self, value_type, weight_type): factory = _zeroed_mean() - value_type = computation_types.to_type(value_type) - weight_type = computation_types.to_type(weight_type) + value_type = federated_language.to_type(value_type) + weight_type = federated_language.to_type(weight_type) process = factory.create(value_type, weight_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) mean_state_type = collections.OrderedDict( value_sum_process=(), weight_sum_process=() ) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( zeroing_norm=(), inner_agg=mean_state_type, zeroed_count_agg=() ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -290,28 +284,28 @@ def test_zero_type_properties_weighted(self, value_type, weight_type): ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( zeroing=collections.OrderedDict(mean_value=(), mean_weight=()), zeroing_norm=robust.NORM_TF_TYPE, zeroed_count=robust.COUNT_TF_TYPE, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), - weight=computation_types.FederatedType( - weight_type, placements.CLIENTS + weight=federated_language.FederatedType( + weight_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -331,17 +325,17 @@ def test_zero_type_properties_with_zeroed_count_agg_factory(self, value_type): norm_order=2.0, zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory(), ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( zeroing_norm=(), inner_agg=(), zeroed_count_agg=np.int32 ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -350,25 +344,25 @@ def test_zero_type_properties_with_zeroed_count_agg_factory(self, value_type): ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( zeroing=(), zeroing_norm=robust.NORM_TF_TYPE, zeroed_count=robust.COUNT_TF_TYPE, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -389,17 +383,17 @@ def test_clip_type_properties_with_clipped_count_agg_factory( inner_agg_factory=sum_factory.SumFactory(), clipped_count_sum_factory=aggregator_test_utils.SumPlusOneFactory(), ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( collections.OrderedDict( clipping_norm=(), inner_agg=(), clipped_count_agg=np.int32 ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -408,25 +402,25 @@ def test_clip_type_properties_with_clipped_count_agg_factory( ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( clipping=(), clipping_norm=robust.NORM_TF_TYPE, clipped_count=robust.COUNT_TF_TYPE, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -441,9 +435,9 @@ def test_clip_type_properties_with_clipped_count_agg_factory( ) def test_clip_preserves_aggregated_dtype_with_mixed_float(self, type_spec): factory = _clipped_sum() - mixed_float = computation_types.to_type(type_spec) + mixed_float = federated_language.to_type(type_spec) process = factory.create(mixed_float) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( mixed_float, process.next.type_signature.result.result.member ) @@ -471,23 +465,29 @@ def test_zero_preserves_aggregated_dtype_with_mixed_float( self, norm_order, type_spec ): factory = _zeroed_sum(norm_order=norm_order) - mixed_float = computation_types.to_type(type_spec) + mixed_float = federated_language.to_type(type_spec) process = factory.create(mixed_float) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( mixed_float, process.next.type_signature.result.result.member ) @parameterized.named_parameters( - ('clip_float_on_clients', 1.0, placements.CLIENTS, _clipped_mean), - ('clip_string_on_server', 'bad', placements.SERVER, _clipped_mean), - ('zero_float_on_clients', 1.0, placements.CLIENTS, _zeroed_mean), - ('zero_string_on_server', 'bad', placements.SERVER, _zeroed_mean), + ('clip_float_on_clients', 1.0, federated_language.CLIENTS, _clipped_mean), + ( + 'clip_string_on_server', + 'bad', + federated_language.SERVER, + _clipped_mean, + ), + ('zero_float_on_clients', 1.0, federated_language.CLIENTS, _zeroed_mean), + ('zero_string_on_server', 'bad', federated_language.SERVER, _zeroed_mean), ) def test_raises_on_bad_norm_process_result( self, value, placement, make_factory ): - report_fn = federated_computation.federated_computation( - lambda s: intrinsics.federated_value(value, placement), _float_at_server + report_fn = federated_language.federated_computation( + lambda s: federated_language.federated_value(value, placement), + _float_at_server, ) norm = _test_norm_process(report_fn=report_fn) @@ -499,7 +499,7 @@ def test_raises_on_bad_norm_process_result( ('zero', _zeroed_mean), ) def test_raises_on_bad_process_next_single_param(self, make_factory): - next_fn = federated_computation.federated_computation( + next_fn = federated_language.federated_computation( lambda state: state, _float_at_server ) norm = _test_norm_process(next_fn=next_fn) @@ -512,7 +512,7 @@ def test_raises_on_bad_process_next_single_param(self, make_factory): ('zero', _zeroed_mean), ) def test_raises_on_bad_process_next_three_params(self, make_factory): - next_fn = federated_computation.federated_computation( + next_fn = federated_language.federated_computation( lambda state, value1, value2: state, _float_at_server, _float_at_clients, @@ -528,10 +528,10 @@ def test_raises_on_bad_process_next_three_params(self, make_factory): ('zero', _zeroed_mean), ) def test_raises_on_bad_process_next_not_float(self, make_factory): - complex_at_clients = computation_types.FederatedType( - np.complex64, placements.CLIENTS + complex_at_clients = federated_language.FederatedType( + np.complex64, federated_language.CLIENTS ) - next_fn = federated_computation.federated_computation( + next_fn = federated_language.federated_computation( lambda state, value: state, _float_at_server, complex_at_clients ) norm = _test_norm_process(next_fn=next_fn) @@ -546,7 +546,7 @@ def test_raises_on_bad_process_next_not_float(self, make_factory): ('zero', _zeroed_mean), ) def test_raises_on_bad_process_next_two_outputs(self, make_factory): - next_fn = federated_computation.federated_computation( + next_fn = federated_language.federated_computation( lambda state, val: (state, state), _float_at_server, _float_at_clients ) norm = _test_norm_process(next_fn=next_fn) @@ -564,7 +564,7 @@ def _check_result(self, expected, result): def test_fixed_clip_sum(self): factory = _clipped_sum() - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory.create(value_type) state = process.initialize() @@ -578,8 +578,8 @@ def test_fixed_clip_sum(self): def test_fixed_clip_mean(self): factory = _clipped_mean() - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -594,7 +594,7 @@ def test_fixed_clip_mean(self): def test_fixed_clip_sum_struct(self): factory = _clipped_sum(4.0) - value_type = computation_types.to_type(_test_struct_type) + value_type = federated_language.to_type(_test_struct_type) process = factory.create(value_type) state = process.initialize() @@ -609,8 +609,8 @@ def test_fixed_clip_sum_struct(self): def test_fixed_clip_mean_struct(self): factory = _clipped_mean(4.0) - value_type = computation_types.to_type(_test_struct_type) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.to_type(_test_struct_type) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -626,7 +626,7 @@ def test_fixed_clip_mean_struct(self): def test_increasing_clip_sum(self): factory = _clipped_sum(_test_norm_process()) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory.create(value_type) state = process.initialize() @@ -650,8 +650,8 @@ def test_increasing_clip_sum(self): def test_increasing_clip_mean(self): factory = _clipped_mean(_test_norm_process()) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -675,7 +675,7 @@ def test_increasing_clip_mean(self): def test_clip_mixed_float_dtype(self): factory = _clipped_sum(clip=3.0) - mixed_float = computation_types.to_type((np.float16, np.float32)) + mixed_float = federated_language.to_type((np.float16, np.float32)) process = factory.create(mixed_float) # Should not clip anything. @@ -691,7 +691,7 @@ def test_clip_mixed_float_dtype(self): def test_fixed_zero_sum(self): factory = _zeroed_sum() - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory.create(value_type) state = process.initialize() @@ -705,8 +705,8 @@ def test_fixed_zero_sum(self): def test_fixed_zero_mean(self): factory = _zeroed_mean() - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -721,7 +721,7 @@ def test_fixed_zero_mean(self): def test_fixed_zero_sum_struct(self): factory = _zeroed_sum(4.0) - value_type = computation_types.to_type(_test_struct_type) + value_type = federated_language.to_type(_test_struct_type) process = factory.create(value_type) state = process.initialize() @@ -735,8 +735,8 @@ def test_fixed_zero_sum_struct(self): def test_fixed_zero_mean_struct(self): factory = _zeroed_mean(4.0) - value_type = computation_types.to_type(_test_struct_type) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.to_type(_test_struct_type) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -751,7 +751,7 @@ def test_fixed_zero_mean_struct(self): def test_fixed_zero_sum_struct_inf_norm(self): factory = _zeroed_sum(2.0, float('inf')) - value_type = computation_types.to_type(_test_struct_type) + value_type = federated_language.to_type(_test_struct_type) process = factory.create(value_type) state = process.initialize() @@ -765,8 +765,8 @@ def test_fixed_zero_sum_struct_inf_norm(self): def test_fixed_zero_mean_struct_inf_norm(self): factory = _zeroed_mean(2.0, float('inf')) - value_type = computation_types.to_type(_test_struct_type) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.to_type(_test_struct_type) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -781,7 +781,7 @@ def test_fixed_zero_mean_struct_inf_norm(self): def test_increasing_zero_sum(self): factory = _zeroed_sum(_test_norm_process()) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory.create(value_type) state = process.initialize() @@ -805,8 +805,8 @@ def test_increasing_zero_sum(self): def test_increasing_zero_mean(self): factory = _zeroed_mean(_test_norm_process()) - value_type = computation_types.TensorType(np.float32) - weight_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) + weight_type = federated_language.TensorType(np.float32) process = factory.create(value_type, weight_type) state = process.initialize() @@ -832,22 +832,22 @@ def test_increasing_zero_clip_sum(self): # Tests when zeroing and clipping are performed with non-integer clips. # Zeroing norm grows by 0.75 each time, clipping norm grows by 0.25. - @federated_computation.federated_computation( + @federated_language.federated_computation( _float_at_server, _float_at_clients ) def zeroing_next_fn(state, value): del value - return intrinsics.federated_map( + return federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: x + 0.75, np.float32), state, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( _float_at_server, _float_at_clients ) def clipping_next_fn(state, value): del value - return intrinsics.federated_map( + return federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: x + 0.25, np.float32), state, ) @@ -863,7 +863,7 @@ def clipping_next_fn(state, value): zeroing_norm_process, _clipped_sum(clipping_norm_process) ) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) process = factory.create(value_type) state = process.initialize() @@ -911,7 +911,7 @@ def clipping_next_fn(state, value): ) def test_zero_mixed_float_dtype(self, norm_order): factory = _zeroed_sum(clip=3.0, norm_order=norm_order) - mixed_float = computation_types.to_type((np.float16, np.float32)) + mixed_float = federated_language.to_type((np.float16, np.float32)) process = factory.create(mixed_float) # Should not zero-out anything. diff --git a/tensorflow_federated/python/aggregators/rotation.py b/tensorflow_federated/python/aggregators/rotation.py index 82b7069a73..08967290ae 100644 --- a/tensorflow_federated/python/aggregators/rotation.py +++ b/tensorflow_federated/python/aggregators/rotation.py @@ -17,6 +17,7 @@ import math from typing import Optional +import federated_language import numpy as np import tensorflow as tf @@ -26,16 +27,11 @@ from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process SEED_TF_DTYPE = tf.int64 -SEED_TFF_TYPE = computation_types.TensorType(np.int64, [2]) +SEED_TFF_TYPE = federated_language.TensorType(np.int64, [2]) OUTPUT_TF_DTYPE = np.float32 @@ -141,18 +137,20 @@ def transform(tensor, seed): _slice_and_reshape_to_template_spec, value, value_specs ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): inner_state = inner_agg_process.initialize() - my_state = intrinsics.federated_eval( + my_state = federated_language.federated_eval( tensorflow_computation.tf_computation(_init_global_seed), - placements.SERVER, + federated_language.SERVER, ) - return intrinsics.federated_zip((inner_state, my_state)) + return federated_language.federated_zip((inner_state, my_state)) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): next_fn_impl = _build_next_fn( @@ -283,18 +281,20 @@ def transform(tensor, seed): _slice_and_reshape_to_template_spec, value, value_specs ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): inner_state = inner_agg_process.initialize() - my_state = intrinsics.federated_eval( + my_state = federated_language.federated_eval( tensorflow_computation.tf_computation(_init_global_seed), - placements.SERVER, + federated_language.SERVER, ) - return intrinsics.federated_zip((inner_state, my_state)) + return federated_language.federated_zip((inner_state, my_state)) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): next_fn_impl = _build_next_fn( @@ -316,29 +316,29 @@ def _build_next_fn( def next_fn_impl(state, value): inner_state, my_state = state - client_my_state = intrinsics.federated_broadcast(my_state) - projected_value = intrinsics.federated_map( + client_my_state = federated_language.federated_broadcast(my_state) + projected_value = federated_language.federated_map( client_transform, (value, client_my_state) ) inner_agg_output = inner_agg_process.next(inner_state, projected_value) - aggregate_value = intrinsics.federated_map( + aggregate_value = federated_language.federated_map( server_transform, (inner_agg_output.result, my_state) ) new_state = ( inner_agg_output.state, - intrinsics.federated_map(update_my_state, my_state), + federated_language.federated_map(update_my_state, my_state), ) measurements = collections.OrderedDict( [(name, inner_agg_output.measurements)] ) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip(new_state), + state=federated_language.federated_zip(new_state), result=aggregate_value, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return next_fn_impl @@ -484,10 +484,10 @@ def _pad_zeros_even(x): def _check_value_type(value_type): """Check value_type meets documented criteria.""" if not ( - isinstance(value_type, computation_types.TensorType) + isinstance(value_type, federated_language.TensorType) or ( - isinstance(value_type, computation_types.StructWithPythonType) - and type_analysis.is_structure_of_tensors(value_type) + isinstance(value_type, federated_language.StructWithPythonType) + and federated_language.framework.is_structure_of_tensors(value_type) ) ): raise TypeError( @@ -497,8 +497,8 @@ def _check_value_type(value_type): ) if not ( - type_analysis.is_structure_of_floats(value_type) - or type_analysis.is_structure_of_integers(value_type) + federated_language.framework.is_structure_of_floats(value_type) + or federated_language.framework.is_structure_of_integers(value_type) ): raise TypeError( 'Component dtypes of `value_type` must be all integers or ' diff --git a/tensorflow_federated/python/aggregators/rotation_test.py b/tensorflow_federated/python/aggregators/rotation_test.py index a80483d2d9..008f81d4f3 100644 --- a/tensorflow_federated/python/aggregators/rotation_test.py +++ b/tensorflow_federated/python/aggregators/rotation_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -23,9 +24,6 @@ from tensorflow_federated.python.aggregators import rotation from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -112,40 +110,40 @@ class RotationsComputationTest(tf.test.TestCase, parameterized.TestCase): ) def test_type_properties(self, name, value_type): factory = _hadamard_sum() if name == 'hd' else _dft_sum() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.FederatedType( - ((), rotation.SEED_TFF_TYPE), placements.SERVER + server_state_type = federated_language.FederatedType( + ((), rotation.SEED_TFF_TYPE), federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.initialize.type_signature, expected_initialize_type ) - expected_measurements_type = computation_types.FederatedType( - collections.OrderedDict([(name, ())]), placements.SERVER + expected_measurements_type = federated_language.FederatedType( + collections.OrderedDict([(name, ())]), federated_language.SERVER ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.next.type_signature, expected_next_type ) @@ -166,7 +164,7 @@ def test_raises_on_non_numeric_component_tensor_dtypes( self, factory_fn, value_type ): factory = factory_fn() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'all integers or all floats'): factory.create(value_type) @@ -176,28 +174,28 @@ def test_raises_on_non_numeric_component_tensor_dtypes( ( 'sequence_hadamard', _hadamard_sum, - computation_types.SequenceType(np.int32), + federated_language.SequenceType(np.int32), ), - ('sequence_dft', _dft_sum, computation_types.SequenceType(np.int32)), + ('sequence_dft', _dft_sum, federated_language.SequenceType(np.int32)), ( 'func_hadamard', _hadamard_sum, - computation_types.FunctionType(np.int32, np.int32), + federated_language.FunctionType(np.int32, np.int32), ), ( 'func_dft', _dft_sum, - computation_types.FunctionType(np.int32, np.int32), + federated_language.FunctionType(np.int32, np.int32), ), ( 'nested_sequence', _dft_sum, - [[[computation_types.SequenceType(np.int32)]]], + [[[federated_language.SequenceType(np.int32)]]], ), ) def test_raises_on_bad_tff_value_types(self, factory_fn, value_type): factory = factory_fn() - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'): factory.create(value_type) @@ -262,7 +260,7 @@ class RotationsExecutionTest(tf.test.TestCase, parameterized.TestCase): def test_sum(self, value_type, client_data, expected_sum, factory_fn): """Integration test with sum for the all implementations.""" factory = factory_fn() - process = factory.create(computation_types.to_type(value_type)) + process = factory.create(federated_language.to_type(value_type)) state = process.initialize() for _ in range(3): @@ -325,7 +323,7 @@ def test_sum(self, value_type, client_data, expected_sum, factory_fn): def test_mean(self, value_type, client_data, expected_mean, factory_fn): """Integration test for the factory with mean.""" factory = factory_fn() - process = factory.create(computation_types.to_type(value_type)) + process = factory.create(federated_language.to_type(value_type)) state = process.initialize() for _ in range(3): @@ -346,7 +344,7 @@ def test_inner_aggregation_acts_on_padded_space( ): factory = _measured_hadamard_sum() if name == 'hd' else _measured_dft_sum() process = factory.create( - computation_types.to_type((np.float32, input_shape)) + federated_language.to_type((np.float32, input_shape)) ) client_input = np.ones(input_shape, np.float32) @@ -357,7 +355,7 @@ def test_inner_aggregation_acts_on_padded_space( @parameterized.named_parameters(('hd', 'hd'), ('dft', 'dft')) def test_inner_aggregation_acts_on_rotated_space(self, name): factory = _measured_hadamard_sum() if name == 'hd' else _measured_dft_sum() - process = factory.create(computation_types.TensorType(np.float32, [8])) + process = factory.create(federated_language.TensorType(np.float32, [8])) client_input = np.array( [1.0, -1.0, 2.5, -1.5, -0.5, 1.9, 2.2, -2.0], np.float32 @@ -384,7 +382,7 @@ def test_inner_aggregation_acts_on_rotated_space(self, name): def test_hd_spreads_information(self): factory = _measured_hadamard_sum() - process = factory.create(computation_types.TensorType(np.float32, [256])) + process = factory.create(federated_language.TensorType(np.float32, [256])) client_input = ( 256 * tf.one_hot(indices=17, depth=256, dtype=tf.float32).numpy() @@ -403,7 +401,7 @@ def test_hd_spreads_information(self): def test_dft_spreads_information(self): factory = _measured_dft_sum() - process = factory.create(computation_types.TensorType(np.float32, [256])) + process = factory.create(federated_language.TensorType(np.float32, [256])) client_input = ( 256 * tf.one_hot(indices=17, depth=256, dtype=tf.float32).numpy() diff --git a/tensorflow_federated/python/aggregators/sampling.py b/tensorflow_federated/python/aggregators/sampling.py index f9cf52e7c0..63ac630046 100644 --- a/tensorflow_federated/python/aggregators/sampling.py +++ b/tensorflow_federated/python/aggregators/sampling.py @@ -16,6 +16,7 @@ import collections from typing import Any, Optional +import federated_language import numpy as np import tensorflow as tf @@ -23,13 +24,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_transformations from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -40,31 +34,31 @@ def _is_tensor_or_structure_of_tensors( - value_type: computation_types.Type, + value_type: federated_language.Type, ) -> bool: """Return True if `value_type` is a TensorType or structure of TensorTypes.""" # TODO: b/181365504 - relax this to allow `StructType` once a `Struct` can be # returned from `tf.function` decorated methods. def is_tensor_or_struct_with_py_type( - type_spec: computation_types.Type, + type_spec: federated_language.Type, ) -> bool: return isinstance( type_spec, ( - computation_types.TensorType, - computation_types.StructWithPythonType, + federated_language.TensorType, + federated_language.StructWithPythonType, ), ) - return type_analysis.contains_only( + return federated_language.framework.type_contains_only( value_type, is_tensor_or_struct_with_py_type ) def build_reservoir_type( - sample_value_type: computation_types.Type, -) -> computation_types.Type: + sample_value_type: federated_language.Type, +) -> federated_language.Type: """Create the TFF type for the reservoir's state. `UnweightedReservoirSamplingFactory` will use this type as the "state" type in @@ -101,9 +95,9 @@ def build_reservoir_type( ) def add_unknown_dimension(t): - if isinstance(t, computation_types.TensorType): + if isinstance(t, federated_language.TensorType): return ( - computation_types.TensorType(dtype=t.dtype, shape=(None,) + t.shape), + federated_language.TensorType(dtype=t.dtype, shape=(None,) + t.shape), True, ) return t, False @@ -111,11 +105,11 @@ def add_unknown_dimension(t): # TODO: b/181155367 - creating a value from a type for the `zero` is a common # pattern for users of `tff.federated_aggregate` that could be made easier # for TFF users. Replace this once such helper exists. - return computation_types.to_type( + return federated_language.to_type( collections.OrderedDict( - random_seed=computation_types.TensorType(np.int64, shape=[2]), - random_values=computation_types.TensorType(np.int32, shape=[None]), - samples=type_transformations.transform_type_postorder( + random_seed=federated_language.TensorType(np.int64, shape=[2]), + random_values=federated_language.TensorType(np.int32, shape=[None]), + samples=federated_language.framework.transform_type_postorder( sample_value_type, add_unknown_dimension )[0], ) @@ -123,7 +117,7 @@ def add_unknown_dimension(t): def build_initial_sample_reservoir( - sample_value_type: computation_types.Type, seed: Optional[Any] = None + sample_value_type: federated_language.Type, seed: Optional[Any] = None ): """Build up the initial state of the reservoir for sampling. @@ -148,7 +142,7 @@ def initialize(): else: real_seed = tf.convert_to_tensor(seed, dtype=tf.int64) - def zero_for_tensor_type(t: computation_types.TensorType): + def zero_for_tensor_type(t: federated_language.TensorType): """Add an extra first dimension to create a tensor that collects samples. The first dimension will have size `0` for the algebraic zero, resulting @@ -166,7 +160,7 @@ def zero_for_tensor_type(t: computation_types.TensorType): `TypeError` if `t` is not a `tff.TensorType`. ValueError: If `t.shape` is `None`' """ - if not isinstance(t, computation_types.TensorType): + if not isinstance(t, federated_language.TensorType): raise TypeError(f'Cannot create zero for non TesnorType: {type(t)}') if t.shape is None: raise ValueError('Expected `t.shape` to not be `None`.') @@ -192,8 +186,8 @@ def zero_for_tensor_type(t: computation_types.TensorType): def _build_sample_value_computation( - value_type: computation_types.Type, sample_size: int -) -> computation_base.Computation: + value_type: federated_language.Type, sample_size: int +) -> federated_language.framework.Computation: """Builds the `accumulate` computation for sampling.""" reservoir_type = build_reservoir_type(value_type) @@ -284,8 +278,8 @@ def perform_sampling(reservoir, sample): def build_merge_samples_computation( - value_type: computation_types.Type, sample_size: int -) -> computation_base.Computation: + value_type: federated_language.Type, sample_size: int +) -> federated_language.framework.Computation: """Builds the `merge` computation for a sampling.""" reservoir_type = build_reservoir_type(value_type) @@ -334,9 +328,9 @@ def merge_samples(a, b): def _build_finalize_sample_computation( - value_type: computation_types.Type, + value_type: federated_language.Type, return_sampling_metadata: bool = False, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Builds the `report` computation for sampling.""" reservoir_type = build_reservoir_type(value_type) @@ -352,8 +346,8 @@ def finalize_samples(reservoir): def _build_check_non_finite_leaves_computation( - value_type: computation_types.Type, -) -> computation_base.Computation: + value_type: federated_language.Type, +) -> federated_language.framework.Computation: """Builds the computation for checking non-finite leaves in the client value. Args: @@ -462,27 +456,31 @@ def __init__(self, sample_size: int, return_sampling_metadata: bool = False): def create( self, - value_type: computation_types.Type, + value_type: federated_language.Type, ) -> aggregation_process.AggregationProcess: - if not type_analysis.is_structure_of_tensors(value_type): + if not federated_language.framework.is_structure_of_tensors(value_type): raise TypeError( f'`value_type` must be a structure of tensors, got a {value_type!r}.' ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): # Empty/null state, nothing is tracked across invocations. - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType(value_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(unused_state, value): # Empty tuple is the `None` of TFF. - empty_tuple = intrinsics.federated_value((), placements.SERVER) - non_finite_leaves_counts = intrinsics.federated_sum( - intrinsics.federated_map( + empty_tuple = federated_language.federated_value( + (), federated_language.SERVER + ) + non_finite_leaves_counts = federated_language.federated_sum( + federated_language.federated_map( _build_check_non_finite_leaves_computation(value_type), value ) ) @@ -496,7 +494,7 @@ def next_fn(unused_state, value): finalize_sample = _build_finalize_sample_computation( value_type, self._return_sampling_metadata ) - samples = intrinsics.federated_aggregate( + samples = federated_language.federated_aggregate( value, zero=initial_reservoir, accumulate=sample_value, diff --git a/tensorflow_federated/python/aggregators/sampling_test.py b/tensorflow_federated/python/aggregators/sampling_test.py index a92e85151c..951785afb1 100644 --- a/tensorflow_federated/python/aggregators/sampling_test.py +++ b/tensorflow_federated/python/aggregators/sampling_test.py @@ -15,26 +15,24 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import sampling from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils # Convenience type aliases. -FunctionType = computation_types.FunctionType -SequenceType = computation_types.SequenceType -StructType = computation_types.StructType -StructWithPythonType = computation_types.StructWithPythonType -TensorType = computation_types.TensorType +FunctionType = federated_language.FunctionType +SequenceType = federated_language.SequenceType +StructType = federated_language.StructType +StructWithPythonType = federated_language.StructWithPythonType +TensorType = federated_language.TensorType # Type for the random seed used in sampling is int64 tensor with shape [2]. -SEED_TYPE = computation_types.TensorType(np.int64, shape=[2]) +SEED_TYPE = federated_language.TensorType(np.int64, shape=[2]) TEST_SEED = 42 -RANDOM_VALUE_TYPE = computation_types.TensorType(np.int32, [None]) +RANDOM_VALUE_TYPE = federated_language.TensorType(np.int32, [None]) class BuildReservoirTypeTest(tf.test.TestCase): @@ -55,7 +53,7 @@ def test_scalar(self): def test_structure_of_tensors(self): self.assertEqual( sampling.build_reservoir_type( - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=TensorType(np.float32), b=[TensorType(np.int64, [2]), TensorType(np.bool_)], @@ -125,7 +123,7 @@ def test_scalar(self): ) def test_structure_of_tensors(self): - value_type = computation_types.to_type( + value_type = federated_language.to_type( collections.OrderedDict( a=TensorType(np.float32), b=[TensorType(np.int64, [2]), TensorType(np.bool_)], @@ -156,7 +154,7 @@ def test_fails_with_non_tensor_type(self): ) with self.assertRaises(TypeError): sampling.build_initial_sample_reservoir( - sample_value_type=computation_types.to_type( + sample_value_type=federated_language.to_type( collections.OrderedDict(a=SequenceType(TensorType(np.int32))) ), seed=TEST_SEED, @@ -183,7 +181,7 @@ def test_scalar_random_seed(self): ), result=reservoir_type, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( sample_computation.type_signature, expected_type ) # Get the sentinel seed so that the first call initializes based on @@ -218,7 +216,7 @@ def test_scalar_fixed_seed(self): ), result=reservoir_type, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( sample_computation.type_signature, expected_type ) reservoir = sampling.build_initial_sample_reservoir( @@ -257,7 +255,7 @@ def test_scalar_fixed_seed(self): ) def test_structure_of_tensors(self): - example_type = computation_types.to_type( + example_type = federated_language.to_type( collections.OrderedDict( a=TensorType(np.int32, [3]), b=[TensorType(np.float32), TensorType(np.bool_)], @@ -273,7 +271,7 @@ def test_structure_of_tensors(self): ), result=reservoir_type, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( sample_computation.type_signature, expected_type ) reservoir = sampling.build_initial_sample_reservoir( @@ -339,7 +337,7 @@ def test_scalar(self): parameter=collections.OrderedDict(a=reservoir_type, b=reservoir_type), result=reservoir_type, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( merge_computation.type_signature, expected_type ) reservoir_a = sampling.build_initial_sample_reservoir( @@ -399,7 +397,7 @@ def test_scalar(self): ) def test_structure_of_tensors(self): - example_type = computation_types.to_type( + example_type = federated_language.to_type( collections.OrderedDict( a=TensorType(np.int32, [3]), b=[TensorType(np.float32), TensorType(np.bool_)], @@ -413,7 +411,7 @@ def test_structure_of_tensors(self): parameter=collections.OrderedDict(a=reservoir_type, b=reservoir_type), result=reservoir_type, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( merge_computation.type_signature, expected_type ) reservoir_a = sampling.build_initial_sample_reservoir( @@ -511,7 +509,7 @@ def assertAllEqual(self, *args, **kwargs): self.assertAllClose(*args, **kwargs, atol=0.0, rtol=0.0) def test_scalar(self): - example_type = computation_types.to_type(TensorType(np.int32)) + example_type = federated_language.to_type(TensorType(np.int32)) finalize_computation = sampling._build_finalize_sample_computation( example_type ) @@ -519,7 +517,7 @@ def test_scalar(self): expected_type = FunctionType( parameter=reservoir_type, result=reservoir_type.samples ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( finalize_computation.type_signature, expected_type ) reservoir = sampling.build_initial_sample_reservoir( @@ -531,7 +529,7 @@ def test_scalar(self): self.assertAllEqual(finalize_computation(reservoir), test_samples) def test_structure(self): - example_type = computation_types.to_type( + example_type = federated_language.to_type( collections.OrderedDict( a=TensorType(np.int32), b=[ @@ -547,7 +545,7 @@ def test_structure(self): expected_type = FunctionType( parameter=reservoir_type, result=reservoir_type.samples ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( finalize_computation.type_signature, expected_type ) reservoir = sampling.build_initial_sample_reservoir( @@ -585,7 +583,7 @@ def test_scalar(self, dtype, value, is_non_finite): self.assertEqual(result, expected_result) def test_structure(self): - value_type = computation_types.to_type( + value_type = federated_language.to_type( collections.OrderedDict( a=TensorType(np.int32), b=[TensorType(np.float32, [3]), TensorType(np.bool_)], @@ -615,7 +613,7 @@ def test_fails_with_non_tensor_type(self): ) with self.assertRaisesRegex(TypeError, 'only contain `TensorType`s'): sampling._build_check_non_finite_leaves_computation( - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=TensorType(np.float32, [3]), b=[SequenceType(TensorType(np.int32)), TensorType(np.bool_)], @@ -636,10 +634,10 @@ def test_create(self, return_sampling_metadata): sample_size=10, return_sampling_metadata=return_sampling_metadata ) with self.subTest('scalar_aggregator'): - factory.create(computation_types.TensorType(np.int32)) + factory.create(federated_language.TensorType(np.int32)) with self.subTest('structure_aggregator'): factory.create( - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=TensorType(np.int32), b=[TensorType(np.float32, [3]), TensorType(np.bool_)], @@ -656,21 +654,23 @@ def test_create_fails_with_invalid_value_type(self, return_sampling_metadata): ) with self.subTest('function_type'): with self.assertRaisesRegex(TypeError, 'must be a structure of tensors'): - factory.create(computation_types.FunctionType(None, np.int32)) + factory.create(federated_language.FunctionType(None, np.int32)) with self.subTest('sequence_type'): with self.assertRaisesRegex(TypeError, 'must be a structure of tensors'): - factory.create(computation_types.SequenceType(np.int32)) + factory.create(federated_language.SequenceType(np.int32)) with self.subTest('federated_type'): with self.assertRaisesRegex(TypeError, 'must be a structure of tensors'): factory.create( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) ) @parameterized.named_parameters(('two_samples', 2), ('four_samples', 4)) def test_sample_size_limits(self, sample_size): process = sampling.UnweightedReservoirSamplingFactory( sample_size=sample_size - ).create(computation_types.TensorType(np.int32)) + ).create(federated_language.TensorType(np.int32)) state = process.initialize() output = process.next( state, @@ -690,7 +690,7 @@ def test_sample_size_limits(self, sample_size): def test_sample_size_limits_with_sampling_metadata(self, sample_size): process = sampling.UnweightedReservoirSamplingFactory( sample_size=sample_size, return_sampling_metadata=True - ).create(computation_types.TensorType(np.int32)) + ).create(federated_language.TensorType(np.int32)) state = process.initialize() output = process.next( state, @@ -710,7 +710,7 @@ def test_sample_size_limits_with_sampling_metadata(self, sample_size): def test_unfilled_reservoir(self): process = sampling.UnweightedReservoirSamplingFactory(sample_size=4).create( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ) state = process.initialize() # Create 3 client values to aggregate. @@ -730,7 +730,7 @@ def test_unfilled_reservoir(self): def test_unfilled_reservoir_with_sampling_metadata(self): process = sampling.UnweightedReservoirSamplingFactory( sample_size=4, return_sampling_metadata=True - ).create(computation_types.TensorType(np.int32)) + ).create(federated_language.TensorType(np.int32)) state = process.initialize() # Create 3 client values to aggregate. client_values = ( @@ -760,7 +760,7 @@ def test_build_factory_fails_invalid_argument(self): def test_measurements_scalar_value(self): process = sampling.UnweightedReservoirSamplingFactory(sample_size=1).create( - computation_types.TensorType(np.float32) + federated_language.TensorType(np.float32) ) state = process.initialize() output = process.next(state, [1.0, np.nan, np.inf, 2.0, 3.0]) @@ -769,7 +769,7 @@ def test_measurements_scalar_value(self): def test_measurements_structure_value(self): process = sampling.UnweightedReservoirSamplingFactory(sample_size=1).create( - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=TensorType(np.float32), b=[TensorType(np.float32, [2, 2]), TensorType(np.bool_)], diff --git a/tensorflow_federated/python/aggregators/secure.py b/tensorflow_federated/python/aggregators/secure.py index 8e0606ea20..2b187cb80c 100644 --- a/tensorflow_federated/python/aggregators/secure.py +++ b/tensorflow_federated/python/aggregators/secure.py @@ -19,6 +19,7 @@ import typing from typing import Optional, Union +import federated_language import numpy as np import tensorflow as tf @@ -28,11 +29,6 @@ from tensorflow_federated.python.core.backends.mapreduce import intrinsics as mapreduce_intrinsics from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process @@ -121,19 +117,21 @@ def __init__( def create(self, value_type): type_args = typing.get_args(factory.ValueType) py_typecheck.check_type(value_type, type_args) - if not type_analysis.is_structure_of_integers(value_type): + if not federated_language.framework.is_structure_of_integers(value_type): raise TypeError( 'Provided value_type must either be an integer type or' f'a structure of integer types, but found: {value_type}' ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): if self._symmetric_range: @@ -147,7 +145,7 @@ def next_fn(state, value): summed_value = mapreduce_intrinsics.federated_secure_modular_sum( value, 2 * self._modulus - 1 ) - summed_value = intrinsics.federated_map( + summed_value = federated_language.federated_map( tensorflow_computation.tf_computation( self._mod_clip_after_symmetric_range_sum ), @@ -157,7 +155,9 @@ def next_fn(state, value): summed_value = mapreduce_intrinsics.federated_secure_modular_sum( value, self._modulus ) - empty_measurements = intrinsics.federated_value((), placements.SERVER) + empty_measurements = federated_language.federated_value( + (), federated_language.SERVER + ) return measured_process.MeasuredProcessOutput( state, summed_value, empty_measurements ) @@ -202,7 +202,7 @@ def _check_bound_process( next_parameter_type = bound_process.next.type_signature.parameter if ( - not isinstance(next_parameter_type, computation_types.StructType) + not isinstance(next_parameter_type, federated_language.StructType) or len(next_parameter_type) != 2 ): raise TypeError( @@ -210,8 +210,8 @@ def _check_bound_process( f'{next_parameter_type}' ) - float_type_at_clients = computation_types.FederatedType( - NORM_TYPE, placements.CLIENTS + float_type_at_clients = federated_language.FederatedType( + NORM_TYPE, federated_language.CLIENTS ) if not next_parameter_type[1].is_assignable_from(float_type_at_clients): # pytype: disable=unsupported-operands raise TypeError( @@ -228,9 +228,9 @@ def _check_bound_process( ) report_type = bound_process.report.type_signature.result - estimated_value_type_at_server = computation_types.FederatedType( + estimated_value_type_at_server = federated_language.FederatedType( next_parameter_type[1].member, # pytype: disable=unsupported-operands - placements.SERVER, + federated_language.SERVER, ) if not report_type.is_assignable_from(estimated_value_type_at_server): raise TypeError( @@ -383,14 +383,16 @@ def create( ) -> aggregation_process.AggregationProcess: self._check_value_type_compatible_with_config_mode(value_type) - @federated_computation.federated_computation( + @federated_language.federated_computation( self._init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): # Compute min and max *before* clipping and use it to update the state. - value_max = intrinsics.federated_map(_reduce_nest_max, value) - value_min = intrinsics.federated_map(_reduce_nest_min, value) + value_max = federated_language.federated_map(_reduce_nest_max, value) + value_min = federated_language.federated_map(_reduce_nest_min, value) upper_bound, lower_bound = self._get_bounds_from_state( state, value_max.type_signature.member.dtype # pytype: disable=attribute-error ) @@ -414,22 +416,22 @@ def _compute_measurements( self, upper_bound, lower_bound, value_max, value_min ): """Creates measurements to be reported. All values are summed securely.""" - is_max_clipped = intrinsics.federated_map( + is_max_clipped = federated_language.federated_map( tensorflow_computation.tf_computation( lambda bound, value: tf.cast(bound < value, COUNT_TYPE) ), - (intrinsics.federated_broadcast(upper_bound), value_max), + (federated_language.federated_broadcast(upper_bound), value_max), ) - max_clipped_count = intrinsics.federated_secure_sum_bitwidth( + max_clipped_count = federated_language.federated_secure_sum_bitwidth( is_max_clipped, bitwidth=1 ) - is_min_clipped = intrinsics.federated_map( + is_min_clipped = federated_language.federated_map( tensorflow_computation.tf_computation( lambda bound, value: tf.cast(bound > value, COUNT_TYPE) ), - (intrinsics.federated_broadcast(lower_bound), value_min), + (federated_language.federated_broadcast(lower_bound), value_min), ) - min_clipped_count = intrinsics.federated_secure_sum_bitwidth( + min_clipped_count = federated_language.federated_secure_sum_bitwidth( is_min_clipped, bitwidth=1 ) measurements = collections.OrderedDict( @@ -438,26 +440,26 @@ def _compute_measurements( secure_upper_threshold=upper_bound, secure_lower_threshold=lower_bound, ) - return intrinsics.federated_zip(measurements) + return federated_language.federated_zip(measurements) def _sum_securely(self, value, upper_bound, lower_bound): """Securely sums `value` placed at CLIENTS.""" if self._config_mode == _Config.INT: - value = intrinsics.federated_map( + value = federated_language.federated_map( _client_shift, ( value, - intrinsics.federated_broadcast(upper_bound), - intrinsics.federated_broadcast(lower_bound), + federated_language.federated_broadcast(upper_bound), + federated_language.federated_broadcast(lower_bound), ), ) - value = intrinsics.federated_secure_sum_bitwidth( + value = federated_language.federated_secure_sum_bitwidth( value, self._secagg_bitwidth ) - num_summands = intrinsics.federated_secure_sum_bitwidth( + num_summands = federated_language.federated_secure_sum_bitwidth( _client_one(), bitwidth=1 ) - value = intrinsics.federated_map( + value = federated_language.federated_map( _server_shift, (value, lower_bound, num_summands) ) return value @@ -505,14 +507,14 @@ def _check_value_type_compatible_with_config_mode(self, value_type): ) if self._config_mode == _Config.INT: - if not type_analysis.is_structure_of_integers(value_type): + if not federated_language.framework.is_structure_of_integers(value_type): raise TypeError( 'The `SecureSumFactory` was configured to work with integer ' 'dtypes. All values in provided `value_type` hence must be of ' f'integer dtype. \nProvided value_type: {value_type}' ) elif self._config_mode == _Config.FLOAT: - if not type_analysis.is_structure_of_floats(value_type): + if not federated_language.framework.is_structure_of_floats(value_type): raise TypeError( 'The `SecureSumFactory` was configured to work with floating ' 'point dtypes. All values in provided `value_type` hence must be ' @@ -580,15 +582,15 @@ def _server_shift(value, lower_bound, num_summands): ) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _empty_state(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) def _client_one(): - return intrinsics.federated_eval( + return federated_language.federated_eval( tensorflow_computation.tf_computation(lambda: tf.constant(1, tf.int32)), - placements.CLIENTS, + federated_language.CLIENTS, ) @@ -596,9 +598,10 @@ def _create_initial_state_two_processes( upper_bound_process: estimation_process.EstimationProcess, lower_bound_process: estimation_process.EstimationProcess, ): - @federated_computation.federated_computation() + + @federated_language.federated_computation() def initial_state(): - return intrinsics.federated_zip( + return federated_language.federated_zip( (upper_bound_process.initialize(), lower_bound_process.initialize()) ) @@ -616,7 +619,9 @@ def bounds_fn(): def get_bounds(state): del state # Unused. - return intrinsics.federated_eval(bounds_fn, placements.SERVER) + return federated_language.federated_eval( + bounds_fn, federated_language.SERVER + ) return get_bounds @@ -630,8 +635,10 @@ def get_bounds(state): cast_fn = tensorflow_computation.tf_computation( lambda x: tf.cast(x, bound_dtype) ) - upper_bound = intrinsics.federated_map(cast_fn, process.report(state)) - lower_bound = intrinsics.federated_map( + upper_bound = federated_language.federated_map( + cast_fn, process.report(state) + ) + lower_bound = federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: x * -1.0), upper_bound ) return upper_bound, lower_bound @@ -650,10 +657,10 @@ def get_bounds(state): cast_fn = tensorflow_computation.tf_computation( lambda x: tf.cast(x, bound_dtype) ) - upper_bound = intrinsics.federated_map( + upper_bound = federated_language.federated_map( cast_fn, upper_bound_process.report(state[0]) ) - lower_bound = intrinsics.federated_map( + lower_bound = federated_language.federated_map( cast_fn, lower_bound_process.report(state[1]) ) return upper_bound, lower_bound @@ -672,7 +679,9 @@ def update_state(state, value_min, value_max): abs_max_fn = tensorflow_computation.tf_computation( lambda x, y: tf.cast(tf.maximum(tf.abs(x), tf.abs(y)), expected_dtype) ) - abs_value_max = intrinsics.federated_map(abs_max_fn, (value_min, value_max)) + abs_value_max = federated_language.federated_map( + abs_max_fn, (value_min, value_max) + ) return process.next(state, abs_value_max) return update_state @@ -688,15 +697,15 @@ def _create_update_state_two_processes( min_dtype = lower_bound_process.next.type_signature.parameter[1].member.dtype # pytype: disable=unsupported-operands def update_state(state, value_min, value_max): - value_min = intrinsics.federated_map( + value_min = federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: tf.cast(x, min_dtype)), value_min, ) - value_max = intrinsics.federated_map( + value_max = federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: tf.cast(x, max_dtype)), value_max, ) - return intrinsics.federated_zip(( + return federated_language.federated_zip(( upper_bound_process.next(state[0], value_max), lower_bound_process.next(state[1], value_min), )) @@ -705,20 +714,20 @@ def update_state(state, value_min, value_max): def _unique_dtypes_in_structure( - type_spec: computation_types.Type, + type_spec: federated_language.Type, ) -> set[tf.dtypes.DType]: """Returns a set of unique dtypes in `type_spec`. Args: - type_spec: A `computation_types.Type`. + type_spec: A `federated_language.Type`. Returns: A `set` containing unique dtypes found in `type_spec`. """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): + py_typecheck.check_type(type_spec, federated_language.Type) + if isinstance(type_spec, federated_language.TensorType): return set([tf.dtypes.as_dtype(type_spec.dtype)]) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): return set( tf.nest.flatten( type_conversions.structure_from_tensor_type_tree( @@ -726,11 +735,11 @@ def _unique_dtypes_in_structure( ) ) ) - elif isinstance(type_spec, computation_types.FederatedType): + elif isinstance(type_spec, federated_language.FederatedType): return _unique_dtypes_in_structure(type_spec.member) else: return set() -def _is_structure_of_single_dtype(type_spec: computation_types.Type) -> bool: +def _is_structure_of_single_dtype(type_spec: federated_language.Type) -> bool: return len(_unique_dtypes_in_structure(type_spec)) == 1 diff --git a/tensorflow_federated/python/aggregators/secure_test.py b/tensorflow_federated/python/aggregators/secure_test.py index bc17e30090..5e6432bdb1 100644 --- a/tensorflow_federated/python/aggregators/secure_test.py +++ b/tensorflow_federated/python/aggregators/secure_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -22,23 +23,22 @@ from tensorflow_federated.python.aggregators import secure from tensorflow_federated.python.core.backends.test import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.core.test import static_assert -_float_at_server = computation_types.FederatedType( - np.float32, placements.SERVER +_float_at_server = federated_language.FederatedType( + np.float32, federated_language.SERVER ) -_float_at_clients = computation_types.FederatedType( - np.float32, placements.CLIENTS +_float_at_clients = federated_language.FederatedType( + np.float32, federated_language.CLIENTS +) +_int_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER +) +_int_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) -_int_at_server = computation_types.FederatedType(np.int32, placements.SERVER) -_int_at_clients = computation_types.FederatedType(np.int32, placements.CLIENTS) def _test_struct_type(dtype): @@ -46,8 +46,10 @@ def _test_struct_type(dtype): def _test_float_init_fn(factor): - return federated_computation.federated_computation( - lambda: intrinsics.federated_value(factor * 1.0, placements.SERVER) + return federated_language.federated_computation( + lambda: federated_language.federated_value( + factor * 1.0, federated_language.SERVER + ) ) @@ -56,14 +58,14 @@ def _test_float_next_fn(factor): def shift_one(x): return x + (factor * 1.0) - return federated_computation.federated_computation( - lambda state, value: intrinsics.federated_map(shift_one, state), + return federated_language.federated_computation( + lambda state, value: federated_language.federated_map(shift_one, state), _float_at_server, _float_at_clients, ) -_test_float_report_fn = federated_computation.federated_computation( +_test_float_report_fn = federated_language.federated_computation( lambda state: state, _float_at_server ) @@ -77,14 +79,14 @@ def _test_estimation_process(factor): def _measurements_type(bound_type): - return computation_types.FederatedType( + return federated_language.FederatedType( collections.OrderedDict( secure_upper_clipped_count=secure.COUNT_TYPE, secure_lower_clipped_count=secure.COUNT_TYPE, secure_upper_threshold=bound_type, secure_lower_threshold=bound_type, ), - placements.SERVER, + federated_language.SERVER, ) @@ -107,14 +109,16 @@ def test_type_properties(self, modulus, value_type, symmetric_range): modulus=modulus, symmetric_range=symmetric_range ) self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory_.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - expected_state_type = computation_types.FederatedType((), placements.SERVER) + expected_state_type = federated_language.FederatedType( + (), federated_language.SERVER + ) expected_measurements_type = expected_state_type - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -123,17 +127,17 @@ def test_type_properties(self, modulus, value_type, symmetric_range): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=expected_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -141,7 +145,9 @@ def test_type_properties(self, modulus, value_type, symmetric_range): self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) def test_float_modulus_raises(self): with self.assertRaises(TypeError): @@ -160,14 +166,14 @@ def test_symmetric_range_not_bool_raises(self): secure.SecureModularSumFactory(modulus=8, symmetric_range='True') @parameterized.named_parameters( - ('float_type', computation_types.TensorType(np.float32)), - ('mixed_type', computation_types.to_type([np.float32, np.int32])), + ('float_type', federated_language.TensorType(np.float32)), + ('mixed_type', federated_language.to_type([np.float32, np.int32])), ( 'federated_type', - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_value_type_raises(self, bad_value_type): with self.assertRaises(TypeError): @@ -190,7 +196,7 @@ class SecureModularSumFactoryExecutionTest( ) def test_non_symmetric(self, modulus, client_data, expected_sum): factory_ = secure.SecureModularSumFactory(modulus, symmetric_range=False) - process = factory_.create(computation_types.TensorType(np.int32)) + process = factory_.create(federated_language.TensorType(np.int32)) state = process.initialize() output = process.next(state, client_data) self.assertEqual(expected_sum, output.result) @@ -212,7 +218,7 @@ def test_non_symmetric(self, modulus, client_data, expected_sum): ) def test_symmetric(self, modulus, client_data, expected_sum): factory_ = secure.SecureModularSumFactory(modulus, symmetric_range=True) - process = factory_.create(computation_types.TensorType(np.int32)) + process = factory_.create(federated_language.TensorType(np.int32)) state = process.initialize() output = process.next(state, client_data) self.assertEqual(expected_sum, output.result) @@ -220,7 +226,7 @@ def test_symmetric(self, modulus, client_data, expected_sum): def test_struct_type(self): factory_ = secure.SecureModularSumFactory(8) process = factory_.create( - computation_types.to_type(_test_struct_type(np.int32)) + federated_language.to_type(_test_struct_type(np.int32)) ) state = process.initialize() client_data = [ @@ -279,14 +285,16 @@ def test_type_properties_constant_bounds( upper_bound_threshold=upper_bound, lower_bound_threshold=lower_bound ) self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = secure_sum_f.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - expected_state_type = computation_types.FederatedType((), placements.SERVER) + expected_state_type = federated_language.FederatedType( + (), federated_language.SERVER + ) expected_measurements_type = _measurements_type(measurements_dtype) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -295,17 +303,17 @@ def test_type_properties_constant_bounds( ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=expected_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -313,7 +321,9 @@ def test_type_properties_constant_bounds( self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) @parameterized.named_parameters( ('float32_scalar', np.float32, np.float32), @@ -327,17 +337,17 @@ def test_type_properties_single_bound(self, value_type, dtype): upper_bound_threshold=upper_bound_process ) self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = secure_sum_f.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) threshold_type = upper_bound_process.report.type_signature.result.member - expected_state_type = computation_types.FederatedType( - computation_types.to_type(threshold_type), placements.SERVER + expected_state_type = federated_language.FederatedType( + federated_language.to_type(threshold_type), federated_language.SERVER ) expected_measurements_type = _measurements_type(dtype) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -346,17 +356,17 @@ def test_type_properties_single_bound(self, value_type, dtype): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=expected_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -364,7 +374,9 @@ def test_type_properties_single_bound(self, value_type, dtype): self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) @parameterized.named_parameters( ('float32_scalar', np.float32, np.float32), @@ -380,18 +392,18 @@ def test_type_properties_adaptive_bounds(self, value_type, dtype): lower_bound_threshold=lower_bound_process, ) self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = secure_sum_f.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) threshold_type = upper_bound_process.report.type_signature.result.member - expected_state_type = computation_types.FederatedType( - computation_types.to_type((threshold_type, threshold_type)), - placements.SERVER, + expected_state_type = federated_language.FederatedType( + federated_language.to_type((threshold_type, threshold_type)), + federated_language.SERVER, ) expected_measurements_type = _measurements_type(dtype) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -400,17 +412,17 @@ def test_type_properties_adaptive_bounds(self, value_type, dtype): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=expected_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -418,7 +430,9 @@ def test_type_properties_adaptive_bounds(self, value_type, dtype): self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) @parameterized.named_parameters( ('int_smaller', -1, 1), @@ -433,7 +447,7 @@ def test_upper_bound_not_larger_than_lower_bound_raises(self, upper, lower): def test_int_ranges_beyond_2_pow_32(self): secure_sum_f = secure.SecureSumFactory(2**33, -(2**33)) # Bounds this large should be provided only with np.int64 value_type. - process = secure_sum_f.create(computation_types.TensorType(np.int64)) + process = secure_sum_f.create(federated_language.TensorType(np.int64)) self.assertEqual( process.next.type_signature.result.result.member.dtype, np.int64 ) @@ -446,7 +460,7 @@ def test_value_type_incompatible_with_config_mode_raises_int( ): secure_sum_f = secure.SecureSumFactory(upper, lower) with self.assertRaises(TypeError): - secure_sum_f.create(computation_types.TensorType(np.float32)) + secure_sum_f.create(federated_language.TensorType(np.float32)) @parameterized.named_parameters( ('py', 1.0, -1.0), ('np', np.float32(1.0), np.float32(-1.0)) @@ -456,27 +470,29 @@ def test_value_type_incompatible_with_config_mode_raises_float( ): secure_sum_f = secure.SecureSumFactory(upper, lower) with self.assertRaises(TypeError): - secure_sum_f.create(computation_types.TensorType(np.int32)) + secure_sum_f.create(federated_language.TensorType(np.int32)) def test_value_type_incompatible_with_config_mode_raises_single_process(self): secure_sum_f = secure.SecureSumFactory(_test_estimation_process(1)) with self.assertRaises(TypeError): - secure_sum_f.create(computation_types.TensorType(np.int32)) + secure_sum_f.create(federated_language.TensorType(np.int32)) def test_value_type_incompatible_with_config_mode_raises_two_processes(self): secure_sum_f = secure.SecureSumFactory( _test_estimation_process(1), _test_estimation_process(-1) ) with self.assertRaises(TypeError): - secure_sum_f.create(computation_types.TensorType(np.int32)) + secure_sum_f.create(federated_language.TensorType(np.int32)) @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_value_type_raises(self, bad_value_type): secure_sum_f = secure.SecureSumFactory(1.0, -1.0) @@ -490,7 +506,7 @@ def test_int_constant_bounds(self): secure_sum_f = secure.SecureSumFactory( upper_bound_threshold=1, lower_bound_threshold=-1 ) - process = secure_sum_f.create(computation_types.TensorType(np.int32)) + process = secure_sum_f.create(federated_language.TensorType(np.int32)) client_data = [-2, -1, 0, 1, 2, 3] state = process.initialize() @@ -508,7 +524,7 @@ def test_float_constant_bounds(self): secure_sum_f = secure.SecureSumFactory( upper_bound_threshold=1.0, lower_bound_threshold=-1.0 ) - process = secure_sum_f.create(computation_types.TensorType(np.float32)) + process = secure_sum_f.create(federated_language.TensorType(np.float32)) client_data = [-2.5, -0.5, 0.0, 1.0, 1.5, 2.5] state = process.initialize() @@ -526,7 +542,7 @@ def test_float_single_process_bounds(self): secure_sum_f = secure.SecureSumFactory( upper_bound_threshold=_test_estimation_process(1) ) - process = secure_sum_f.create(computation_types.TensorType(np.float32)) + process = secure_sum_f.create(federated_language.TensorType(np.float32)) client_data = [-2.5, -0.5, 0.0, 1.0, 1.5, 3.5] state = process.initialize() @@ -565,7 +581,7 @@ def test_float_two_processes_bounds(self): upper_bound_threshold=_test_estimation_process(1), lower_bound_threshold=_test_estimation_process(-1), ) - process = secure_sum_f.create(computation_types.TensorType(np.float32)) + process = secure_sum_f.create(federated_language.TensorType(np.float32)) client_data = [-2.5, -0.5, 0.0, 1.0, 1.5, 3.5] state = process.initialize() @@ -600,7 +616,7 @@ def test_float_two_processes_bounds(self): def test_float_32_larger_than_2_pow_32(self): secure_sum_f = secure.SecureSumFactory(upper_bound_threshold=float(2**34)) - process = secure_sum_f.create(computation_types.TensorType(np.float32)) + process = secure_sum_f.create(federated_language.TensorType(np.float32)) client_data = [float(2**33), float(2**33), float(2**34)] state = process.initialize() @@ -618,7 +634,7 @@ def test_float_64_larger_than_2_pow_64(self): secure_sum_f = secure.SecureSumFactory( upper_bound_threshold=np.float64(2**66) ) - process = secure_sum_f.create(computation_types.TensorType(np.float64)) + process = secure_sum_f.create(federated_language.TensorType(np.float64)) client_data = [ np.float64(2**65), np.float64(2**65), @@ -663,22 +679,22 @@ def _check_measurements( class IsStructureOfSingleDtypeTest(parameterized.TestCase): @parameterized.named_parameters( - ('bool', computation_types.TensorType(np.bool_)), - ('int', computation_types.TensorType(np.int32)), - ('ints', computation_types.StructType([np.int32, np.int32])), - ('floats', computation_types.StructType([np.float32, np.float32])), + ('bool', federated_language.TensorType(np.bool_)), + ('int', federated_language.TensorType(np.int32)), + ('ints', federated_language.StructType([np.int32, np.int32])), + ('floats', federated_language.StructType([np.float32, np.float32])), ( 'nested_struct', - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.StructType([np.int32, np.int32]), + federated_language.StructType([ + federated_language.TensorType(np.int32), + federated_language.StructType([np.int32, np.int32]), ]), ), ( 'federated_floats_at_clients', - computation_types.FederatedType( - computation_types.StructType([np.float32, np.float32]), - placements.CLIENTS, + federated_language.FederatedType( + federated_language.StructType([np.float32, np.float32]), + federated_language.CLIENTS, ), ), ) @@ -686,24 +702,24 @@ def test_returns_true(self, type_spec): self.assertTrue(secure._is_structure_of_single_dtype(type_spec)) @parameterized.named_parameters( - ('empty_struct', computation_types.StructType([])), - ('int_and_float', computation_types.StructType([np.int32, np.float32])), - ('int32_and_int64', computation_types.StructType([np.int32, np.int64])), + ('empty_struct', federated_language.StructType([])), + ('int_and_float', federated_language.StructType([np.int32, np.float32])), + ('int32_and_int64', federated_language.StructType([np.int32, np.int64])), ( 'float32_and_float64', - computation_types.StructType([np.float32, np.float64]), + federated_language.StructType([np.float32, np.float64]), ), ( 'nested_struct', - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.StructType([np.float32, np.float32]), + federated_language.StructType([ + federated_language.TensorType(np.int32), + federated_language.StructType([np.float32, np.float32]), ]), ), - ('sequence_of_ints', computation_types.SequenceType(np.int32)), - ('placement', computation_types.PlacementType()), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('abstract', computation_types.AbstractType('T')), + ('sequence_of_ints', federated_language.SequenceType(np.int32)), + ('placement', federated_language.PlacementType()), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('abstract', federated_language.AbstractType('T')), ) def test_returns_false(self, type_spec): self.assertFalse(secure._is_structure_of_single_dtype(type_spec)) diff --git a/tensorflow_federated/python/aggregators/stochastic_discretization.py b/tensorflow_federated/python/aggregators/stochastic_discretization.py index 49d33961c8..254d2ef288 100644 --- a/tensorflow_federated/python/aggregators/stochastic_discretization.py +++ b/tensorflow_federated/python/aggregators/stochastic_discretization.py @@ -16,6 +16,7 @@ import collections from typing import Optional +import federated_language import numpy as np import tensorflow as tf @@ -23,11 +24,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -96,11 +92,11 @@ def create( self, value_type: factory.ValueType ) -> aggregation_process.AggregationProcess: # Validate input args and value_type and parse out the TF dtypes. - if isinstance(value_type, computation_types.TensorType): + if isinstance(value_type, federated_language.TensorType): tf_dtype = value_type.dtype elif isinstance( - value_type, computation_types.StructWithPythonType - ) and type_analysis.is_structure_of_tensors(value_type): + value_type, federated_language.StructWithPythonType + ) and federated_language.framework.is_structure_of_tensors(value_type): tf_dtype = type_conversions.structure_from_tensor_type_tree( lambda x: x.dtype, value_type ) @@ -112,7 +108,7 @@ def create( ) # Check that all values are floats. - if not type_analysis.is_structure_of_floats(value_type): + if not federated_language.framework.is_structure_of_floats(value_type): raise TypeError( 'Component dtypes of `value_type` must all be floats. ' f'Found {repr(value_type)}.' @@ -120,7 +116,7 @@ def create( if self._distortion_aggregation_factory is not None: distortion_aggregation_process = self._distortion_aggregation_factory.create( - computation_types.to_type(np.float32) # pytype: disable=wrong-arg-types + federated_language.to_type(np.float32) # pytype: disable=wrong-arg-types ) @tensorflow_computation.tf_computation(value_type, tf.float32) @@ -157,32 +153,36 @@ def distortion_measurement_fn(value, step_size): discretize_fn.type_signature.result ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): state = collections.OrderedDict( - step_size=intrinsics.federated_value( - self._step_size, placements.SERVER + step_size=federated_language.federated_value( + self._step_size, federated_language.SERVER ), inner_agg_process=inner_agg_process.initialize(), ) - return intrinsics.federated_zip(state) + return federated_language.federated_zip(state) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): server_step_size = state['step_size'] - client_step_size = intrinsics.federated_broadcast(server_step_size) + client_step_size = federated_language.federated_broadcast( + server_step_size + ) - discretized_value = intrinsics.federated_map( + discretized_value = federated_language.federated_map( discretize_fn, (value, client_step_size) ) inner_state = state['inner_agg_process'] inner_agg_output = inner_agg_process.next(inner_state, discretized_value) - undiscretized_agg_value = intrinsics.federated_map( + undiscretized_agg_value = federated_language.federated_map( undiscretize_fn, (inner_agg_output.result, server_step_size) ) @@ -194,7 +194,7 @@ def next_fn(state, value): ) if self._distortion_aggregation_factory is not None: - distortions = intrinsics.federated_map( + distortions = federated_language.federated_map( distortion_measurement_fn, (value, client_step_size) ) aggregate_distortion = distortion_aggregation_process.next( @@ -203,9 +203,9 @@ def next_fn(state, value): measurements['distortion'] = aggregate_distortion return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip(new_state), + state=federated_language.federated_zip(new_state), result=undiscretized_agg_value, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return aggregation_process.AggregationProcess(init_fn, next_fn) diff --git a/tensorflow_federated/python/aggregators/stochastic_discretization_test.py b/tensorflow_federated/python/aggregators/stochastic_discretization_test.py index 54db38bc9d..59e69b1f18 100644 --- a/tensorflow_federated/python/aggregators/stochastic_discretization_test.py +++ b/tensorflow_federated/python/aggregators/stochastic_discretization_test.py @@ -16,6 +16,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -25,10 +26,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -61,7 +58,8 @@ def _named_test_cases_product(*args): _measurement_aggregator = measurements.add_measurements( - sum_factory.SumFactory(), client_measurement_fn=intrinsics.federated_sum + sum_factory.SumFactory(), + client_measurement_fn=federated_language.federated_sum, ) @@ -81,49 +79,49 @@ def test_type_properties(self, value_type): inner_agg_factory=_measurement_aggregator, distortion_aggregation_factory=mean.UnweightedMeanFactory(), ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) quantize_type = type_conversions.structure_from_tensor_type_tree( lambda x: (np.int32, x.shape), value_type ) process = factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - server_state_type = computation_types.StructType( + server_state_type = federated_language.StructType( [('step_size', np.float32), ('inner_agg_process', ())] ) - server_state_type = computation_types.FederatedType( - server_state_type, placements.SERVER + server_state_type = federated_language.FederatedType( + server_state_type, federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.initialize.type_signature, expected_initialize_type ) - expected_measurements_type = computation_types.StructType([ + expected_measurements_type = federated_language.StructType([ ('stochastic_discretization', quantize_type), ('distortion', np.float32), ]) - expected_measurements_type = computation_types.FederatedType( - expected_measurements_type, placements.SERVER + expected_measurements_type = federated_language.FederatedType( + expected_measurements_type, federated_language.SERVER ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - value_type, placements.SERVER + result=federated_language.FederatedType( + value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( process.next.type_signature, expected_next_type ) @@ -138,20 +136,20 @@ def test_raises_on_bad_component_tensor_dtypes(self, value_type): factory = stochastic_discretization.StochasticDiscretizationFactory( inner_agg_factory=_measurement_aggregator, step_size=0.1 ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) self.assertRaises(TypeError, factory.create, value_type) @parameterized.named_parameters( ('plain_struct', [('a', np.int32)]), - ('sequence', computation_types.SequenceType(np.int32)), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('nested_sequence', [[[computation_types.SequenceType(np.int32)]]]), + ('sequence', federated_language.SequenceType(np.int32)), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('nested_sequence', [[[federated_language.SequenceType(np.int32)]]]), ) def test_raises_on_bad_tff_value_types(self, value_type): factory = stochastic_discretization.StochasticDiscretizationFactory( inner_agg_factory=_measurement_aggregator, step_size=0.1 ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) self.assertRaises(TypeError, factory.create, value_type) @@ -198,7 +196,7 @@ def test_discretize_impl(self, value_type, client_values, expected_sum): step_size=0.125, distortion_aggregation_factory=mean.UnweightedMeanFactory(), ) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = factory.create(value_type) state = process.initialize() diff --git a/tensorflow_federated/python/aggregators/sum_factory.py b/tensorflow_federated/python/aggregators/sum_factory.py index 8ae9189a44..1e8c21acf6 100644 --- a/tensorflow_federated/python/aggregators/sum_factory.py +++ b/tensorflow_federated/python/aggregators/sum_factory.py @@ -15,12 +15,10 @@ import typing +import federated_language + from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -41,17 +39,21 @@ def create( type_args = typing.get_args(factory.ValueType) py_typecheck.check_type(value_type, type_args) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): - summed_value = intrinsics.federated_sum(value) - empty_measurements = intrinsics.federated_value((), placements.SERVER) + summed_value = federated_language.federated_sum(value) + empty_measurements = federated_language.federated_value( + (), federated_language.SERVER + ) return measured_process.MeasuredProcessOutput( state, summed_value, empty_measurements ) diff --git a/tensorflow_federated/python/aggregators/sum_factory_test.py b/tensorflow_federated/python/aggregators/sum_factory_test.py index 38b07225d9..eb0401a471 100644 --- a/tensorflow_federated/python/aggregators/sum_factory_test.py +++ b/tensorflow_federated/python/aggregators/sum_factory_test.py @@ -15,14 +15,13 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -35,22 +34,24 @@ class SumFactoryComputationTest(tf.test.TestCase, parameterized.TestCase): def test_type_properties(self, value_type): sum_f = sum_factory.SumFactory() self.assertIsInstance(sum_f, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) process = sum_f.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType((), placements.SERVER) - expected_measurements_type = computation_types.FederatedType( - (), placements.SERVER + expected_state_type = federated_language.FederatedType( + (), federated_language.SERVER + ) + expected_measurements_type = federated_language.FederatedType( + (), federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -59,7 +60,7 @@ def test_type_properties(self, value_type): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, value=param_value_type ), @@ -74,10 +75,12 @@ def test_type_properties(self, value_type): @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_value_type_raises(self, bad_value_type): sum_f = sum_factory.SumFactory() @@ -89,7 +92,7 @@ class SumFactoryExecutionTest(tf.test.TestCase): def test_sum_scalar(self): sum_f = sum_factory.SumFactory() - value_type = computation_types.to_type(np.float32) + value_type = federated_language.to_type(np.float32) process = sum_f.create(value_type) state = process.initialize() @@ -103,7 +106,7 @@ def test_sum_scalar(self): def test_sum_structure(self): sum_f = sum_factory.SumFactory() - value_type = computation_types.to_type(((np.float32, (2,)), np.int32)) + value_type = federated_language.to_type(((np.float32, (2,)), np.int32)) process = sum_f.create(value_type) state = process.initialize() diff --git a/tensorflow_federated/python/analytics/BUILD b/tensorflow_federated/python/analytics/BUILD index cefb441c82..bf970ccee9 100644 --- a/tensorflow_federated/python/analytics/BUILD +++ b/tensorflow_federated/python/analytics/BUILD @@ -81,11 +81,7 @@ py_library( srcs = ["count_distinct.py"], deps = [ "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -101,6 +97,6 @@ py_test( deps = [ ":count_distinct", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/analytics/count_distinct.py b/tensorflow_federated/python/analytics/count_distinct.py index 4728fdfc1c..ea39848a21 100644 --- a/tensorflow_federated/python/analytics/count_distinct.py +++ b/tensorflow_federated/python/analytics/count_distinct.py @@ -17,15 +17,11 @@ algorithm. """ +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements # See https://en.wikipedia.org/wiki/HyperLogLog for usage of these constants. # Setting HLL_SKETCH_SIZE = 64 is not currently supported because it is not @@ -53,7 +49,9 @@ def _log2(u: tf.Tensor) -> tf.Tensor: return ans - 1 -def build_client_hyperloglog_computation() -> computation_base.Computation: +def build_client_hyperloglog_computation() -> ( + federated_language.framework.Computation +): """Builds a `tff.Computation` for computing client hyperloglog sketches. Specifically, the returned computation consumes a dataset of integer hashes @@ -64,7 +62,7 @@ def build_client_hyperloglog_computation() -> computation_base.Computation: """ @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.int64) + federated_language.SequenceType(np.int64) ) @tf.function def _client_hyperloglog(client_data: tf.data.Dataset) -> tf.Tensor: @@ -90,7 +88,9 @@ def reduce_func(state, hash_value): return _client_hyperloglog -def build_federated_secure_max_computation() -> computation_base.Computation: +def build_federated_secure_max_computation() -> ( + federated_language.framework.Computation +): """Builds a `tff.Computation` for computing max in a secure fashion. Specifically, the returned computation consumes sketches at @CLIENTS and @@ -108,10 +108,10 @@ def build_federated_secure_max_computation() -> computation_base.Computation: A `tff.Computation` for computing max of client vectors. """ - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.TensorType(np.int64, shape=[HLL_SKETCH_SIZE]), - placements.CLIENTS, + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.TensorType(np.int64, shape=[HLL_SKETCH_SIZE]), + federated_language.CLIENTS, ) ) def federated_secure_max(sketch): @@ -151,9 +151,9 @@ def maxes_from_onehots(x): ) return tf.reduce_max(tf.cast(x > 0, tf.int64) * mult, axis=1) - onehot_sketches = intrinsics.federated_map(_onehot_sketch, sketch) - server_sketch = intrinsics.federated_secure_sum(onehot_sketches, 1) - maxes = intrinsics.federated_map(maxes_from_onehots, server_sketch) + onehot_sketches = federated_language.federated_map(_onehot_sketch, sketch) + server_sketch = federated_language.federated_secure_sum(onehot_sketches, 1) + maxes = federated_language.federated_map(maxes_from_onehots, server_sketch) return maxes return federated_secure_max @@ -161,7 +161,7 @@ def maxes_from_onehots(x): def create_federated_hyperloglog_computation( *, use_secagg: bool = False -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates a `tff.Computation` to estimate the number of distinct strings. The returned computation consumes data @CLIENTS and produces an estimate of @@ -205,19 +205,23 @@ def estimate_count_from_sketch(sketch: tf.Tensor) -> tf.int64: client_hyperloglog = build_client_hyperloglog_computation() federated_secure_max = build_federated_secure_max_computation() - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.SequenceType(np.str_), placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.SequenceType(np.str_), federated_language.CLIENTS ) ) def federated_hyperloglog(client_data): - client_hash = intrinsics.federated_map(hash_client_data, client_data) - sketches = intrinsics.federated_map(client_hyperloglog, client_hash) + client_hash = federated_language.federated_map( + hash_client_data, client_data + ) + sketches = federated_language.federated_map(client_hyperloglog, client_hash) if use_secagg: server_sketch = federated_secure_max(sketches) else: - server_sketch = intrinsics.federated_max(sketches) + server_sketch = federated_language.federated_max(sketches) - return intrinsics.federated_map(estimate_count_from_sketch, server_sketch) + return federated_language.federated_map( + estimate_count_from_sketch, server_sketch + ) return federated_hyperloglog diff --git a/tensorflow_federated/python/analytics/count_distinct_test.py b/tensorflow_federated/python/analytics/count_distinct_test.py index 0d8c1df9be..00463f8350 100644 --- a/tensorflow_federated/python/analytics/count_distinct_test.py +++ b/tensorflow_federated/python/analytics/count_distinct_test.py @@ -14,12 +14,12 @@ """Tests for count_distinct.py.""" from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.analytics import count_distinct from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types def hll_sketch_python(values): @@ -69,9 +69,11 @@ def test_type_properties(self): client_hyperloglog = count_distinct.build_client_hyperloglog_computation() - expected_type_signature = computation_types.FunctionType( - computation_types.SequenceType(computation_types.TensorType(np.int64)), - computation_types.TensorType( + expected_type_signature = federated_language.FunctionType( + federated_language.SequenceType( + federated_language.TensorType(np.int64) + ), + federated_language.TensorType( np.int64, [count_distinct.HLL_SKETCH_SIZE] ), ) diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD b/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD index 52a17dccb0..69f31f96b4 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD @@ -40,11 +40,8 @@ py_library( "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/analytics:data_processing", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", + "@federated_language//federated_language", ], ) @@ -55,7 +52,7 @@ py_test( ":iblt_clipping", ":iblt_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -111,11 +108,7 @@ py_library( "//tensorflow_federated/python/analytics:data_processing", "//tensorflow_federated/python/core/backends/mapreduce:intrinsics", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -127,10 +120,7 @@ py_test( deps = [ ":iblt_tff", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -162,12 +152,9 @@ py_library( "//tensorflow_federated/python/aggregators:sum_factory", "//tensorflow_federated/python/analytics:data_processing", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -181,7 +168,7 @@ py_test( "//tensorflow_federated/python/aggregators:secure", "//tensorflow_federated/python/aggregators:sum_factory", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping.py index 357fd239d2..531cfa4fcd 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping.py @@ -16,6 +16,7 @@ import collections from typing import Optional +import federated_language import numpy as np import tensorflow as tf @@ -23,10 +24,6 @@ from tensorflow_federated.python.analytics import data_processing from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process @@ -161,8 +158,8 @@ def __init__( def create( self, value_type: factory.ValueType ) -> aggregation_process.AggregationProcess: - expected_type = computation_types.SequenceType( - computation_types.TensorType(shape=[None], dtype=np.str_) + expected_type = federated_language.SequenceType( + federated_language.TensorType(shape=[None], dtype=np.str_) ) if value_type != expected_type: @@ -184,12 +181,14 @@ def preprocess(client_data): inner_process = self.inner_iblt_agg.create(preprocess.type_signature.result) - @federated_computation.federated_computation( + @federated_language.federated_computation( inner_process.initialize.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, client_data): - preprocessed = intrinsics.federated_map(preprocess, client_data) + preprocessed = federated_language.federated_map(preprocess, client_data) return inner_process.next(state, preprocessed) return aggregation_process.AggregationProcess( diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping_test.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping_test.py index a344ddcf98..e977890661 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping_test.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_clipping_test.py @@ -13,13 +13,13 @@ # limitations under the License. from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_clipping from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types class IbltClippingTest(parameterized.TestCase): @@ -109,8 +109,8 @@ def test_incorrect_value_type(self): capacity=100, string_max_bytes=10, repetitions=3, seed=0 ) clip_fac = iblt_clipping.ClippingIbltFactory(iblt_fac) - wrong_type = computation_types.SequenceType( - computation_types.TensorType(shape=[None], dtype=np.int32) + wrong_type = federated_language.SequenceType( + federated_language.TensorType(shape=[None], dtype=np.int32) ) with self.assertRaises(ValueError): clip_fac.create(wrong_type) @@ -122,8 +122,8 @@ def test_clipping_factory(self): clip_fac = iblt_clipping.ClippingIbltFactory( iblt_fac, max_words_per_user=4, unique_counts=True ) - value_type = computation_types.SequenceType( - computation_types.TensorType(shape=[None], dtype=np.str_) + value_type = federated_language.SequenceType( + federated_language.TensorType(shape=[None], dtype=np.str_) ) agg_process = clip_fac.create(value_type) data = [ diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory.py index 1d7255c9bf..ebbeee2a04 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory.py @@ -17,6 +17,7 @@ from typing import Any, Optional import attrs +import federated_language import numpy as np import tensorflow as tf @@ -26,10 +27,6 @@ from tensorflow_federated.python.analytics.heavy_hitters.iblt import chunkers from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_tensor from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -159,7 +156,7 @@ def __init__( self._seed = seed def create( - self, value_type: computation_types.SequenceType + self, value_type: federated_language.SequenceType ) -> aggregation_process.AggregationProcess: # pytype: disable=signature-mismatch """Creates an AggregationProcess using IBLT to aggregate strings. @@ -176,12 +173,12 @@ def create( A `tff.templates.AggregationProcess` to aggregate strings and values associate with the strings. """ - expected_value_type = computation_types.SequenceType( + expected_value_type = federated_language.SequenceType( collections.OrderedDict([ (DATASET_KEY, np.str_), ( DATASET_VALUE, - computation_types.TensorType(shape=[None], dtype=np.int64), + federated_language.TensorType(shape=[None], dtype=np.int64), ), ]) ) @@ -235,19 +232,25 @@ def decode_iblt(sketch, value_tensor): encode_iblt.type_signature.result[1] ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): sketch_state = inner_aggregator_sketch.initialize() value_tensor_state = inner_aggregator_value_tensor.initialize() - return intrinsics.federated_zip((sketch_state, value_tensor_state)) + return federated_language.federated_zip( + (sketch_state, value_tensor_state) + ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, dataset): sketch_state, value_tensor_state = state - sketch, value_tensor = intrinsics.federated_map(encode_iblt, dataset) + sketch, value_tensor = federated_language.federated_map( + encode_iblt, dataset + ) sketch_output = inner_aggregator_sketch.next(sketch_state, sketch) value_tensor_output = inner_aggregator_value_tensor.next( value_tensor_state, value_tensor @@ -255,21 +258,21 @@ def next_fn(state, dataset): summed_sketch = sketch_output.result summed_value_tensor = value_tensor_output.result (output_strings, string_values, num_not_decoded) = ( - intrinsics.federated_map( + federated_language.federated_map( decode_iblt, (summed_sketch, summed_value_tensor) ) ) - result = intrinsics.federated_zip( + result = federated_language.federated_zip( ServerOutput( output_strings=output_strings, string_values=string_values, num_not_decoded=num_not_decoded, ) ) - updated_state = intrinsics.federated_zip( + updated_state = federated_language.federated_zip( (sketch_output.state, value_tensor_output.state) ) - updated_measurements = intrinsics.federated_zip( + updated_measurements = federated_language.federated_zip( collections.OrderedDict( num_not_decoded=num_not_decoded, sketch=sketch_output.measurements, diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory_test.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory_test.py index 08d1b21f9c..9f1a5cfcf5 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory_test.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_factory_test.py @@ -17,6 +17,7 @@ from absl import logging from absl.testing import parameterized +import federated_language import grpc import numpy as np import tensorflow as tf @@ -27,7 +28,6 @@ from tensorflow_federated.python.analytics.heavy_hitters.iblt import chunkers from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types # Convenience Aliases _CharacterEncoding = chunkers.CharacterEncoding @@ -44,10 +44,10 @@ ), ] -VALUE_TYPE = computation_types.SequenceType( +VALUE_TYPE = federated_language.SequenceType( collections.OrderedDict( key=np.str_, - value=computation_types.TensorType(shape=(3,), dtype=np.int64), + value=federated_language.TensorType(shape=(3,), dtype=np.int64), ) ) @@ -138,61 +138,61 @@ def test_repetitions_validation(self): @parameterized.named_parameters( ( 'scalar', - computation_types.SequenceType( - computation_types.TensorType(shape=(), dtype=np.int64) + federated_language.SequenceType( + federated_language.TensorType(shape=(), dtype=np.int64) ), ), ( 'list', - computation_types.SequenceType( - computation_types.TensorType(shape=(3,), dtype=np.int64) + federated_language.SequenceType( + federated_language.TensorType(shape=(3,), dtype=np.int64) ), ), ( 'dict_wrong_key', - computation_types.SequenceType( + federated_language.SequenceType( collections.OrderedDict([ ('foo', np.int64), ( iblt_factory.DATASET_VALUE, - computation_types.TensorType(shape=(1,), dtype=np.int64), + federated_language.TensorType(shape=(1,), dtype=np.int64), ), ]) ), ), ( 'dict_extra_key', - computation_types.SequenceType( + federated_language.SequenceType( collections.OrderedDict([ ('bar', np.int64), (iblt_factory.DATASET_KEY, np.int64), ( iblt_factory.DATASET_VALUE, - computation_types.TensorType(shape=(1,), dtype=np.int64), + federated_language.TensorType(shape=(1,), dtype=np.int64), ), ]) ), ), ( 'dict_int64_int64', - computation_types.SequenceType( + federated_language.SequenceType( collections.OrderedDict([ (iblt_factory.DATASET_KEY, np.int64), ( iblt_factory.DATASET_VALUE, - computation_types.TensorType(shape=(1,), dtype=np.int64), + federated_language.TensorType(shape=(1,), dtype=np.int64), ), ]) ), ), ( 'dict_string_int32', - computation_types.SequenceType( + federated_language.SequenceType( collections.OrderedDict([ (iblt_factory.DATASET_KEY, np.str_), ( iblt_factory.DATASET_VALUE, - computation_types.TensorType(shape=(1,), dtype=np.int32), + federated_language.TensorType(shape=(1,), dtype=np.int32), ), ]) ), @@ -213,10 +213,10 @@ def test_string_max_bytes_error(self): ), (iblt_factory.DATASET_VALUE, tf.constant([[1]], dtype=np.int64)), ]) - value_type = computation_types.SequenceType( + value_type = federated_language.SequenceType( collections.OrderedDict( key=np.str_, - value=computation_types.TensorType(shape=(1,), dtype=np.int64), + value=federated_language.TensorType(shape=(1,), dtype=np.int64), ) ) client_data = [tf.data.Dataset.from_tensor_slices(client)] @@ -322,10 +322,10 @@ def test_binary_string_aggregation(self): string_max_bytes=5, ) iblt_agg_process = iblt_agg_factory.create( - value_type=computation_types.SequenceType( + value_type=federated_language.SequenceType( collections.OrderedDict( key=np.str_, - value=computation_types.TensorType( + value=federated_language.TensorType( shape=(2,), dtype=np.int64, ), diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff.py index 4e3e843383..536b80c8bf 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff.py @@ -17,6 +17,7 @@ from typing import Any, Optional import attrs +import federated_language import numpy as np import tensorflow as tf @@ -26,11 +27,6 @@ from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_tensor from tensorflow_federated.python.core.backends.mapreduce import intrinsics as mapreduce_intrinsics from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements # Convenience Aliases _CharacterEncoding = chunkers.CharacterEncoding @@ -76,7 +72,7 @@ def build_iblt_computation( seed: int = 0, batch_size: int = 1, repetitions: int = 3, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Builds the `tff.Computation` for heavy-hitters discovery with IBLT. Args: @@ -178,8 +174,8 @@ def build_iblt_computation( if decode_iblt_fn is None: decode_iblt_fn = iblt_tensor.decode_iblt_tensor_tf - dataset_type = computation_types.SequenceType( - computation_types.TensorType(shape=[None], dtype=np.str_) + dataset_type = federated_language.SequenceType( + federated_language.TensorType(shape=[None], dtype=np.str_) ) @tensorflow_computation.tf_computation(dataset_type) @@ -276,7 +272,7 @@ def decode_heavy_hitters(sketch, count_tensor): ) def secure_sum(x): - return intrinsics.federated_secure_sum( + return federated_language.federated_secure_sum( x, max_input=2**secure_sum_bitwidth - 1 ) @@ -285,8 +281,8 @@ def secure_modular_sum(x): x, modulus=np.int64(iblt_lib.DEFAULT_FIELD_SIZE) ) - @federated_computation.federated_computation( - computation_types.FederatedType(dataset_type, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(dataset_type, federated_language.CLIENTS) ) def one_round_computation(examples): """The TFF computation to compute the aggregated IBLT sketch.""" @@ -296,16 +292,20 @@ def one_round_computation(examples): sketch_sum_fn = secure_modular_sum count_sum_fn = secure_sum else: - sketch_sum_fn = intrinsics.federated_sum - count_sum_fn = intrinsics.federated_sum - round_timestamp = intrinsics.federated_eval( + sketch_sum_fn = federated_language.federated_sum + count_sum_fn = federated_language.federated_sum + round_timestamp = federated_language.federated_eval( tensorflow_computation.tf_computation( lambda: tf.cast(tf.timestamp(), np.int64) ), - placements.SERVER, + federated_language.SERVER, + ) + clients = count_sum_fn( + federated_language.federated_value(1, federated_language.CLIENTS) + ) + sketch, count_tensor = federated_language.federated_map( + compute_sketch, examples ) - clients = count_sum_fn(intrinsics.federated_value(1, placements.CLIENTS)) - sketch, count_tensor = intrinsics.federated_map(compute_sketch, examples) sketch = sketch_sum_fn(sketch) count_tensor = count_sum_fn(count_tensor) @@ -314,8 +314,10 @@ def one_round_computation(examples): heavy_hitters_unique_counts, heavy_hitters_counts, num_not_decoded, - ) = intrinsics.federated_map(decode_heavy_hitters, (sketch, count_tensor)) - server_output = intrinsics.federated_zip( + ) = federated_language.federated_map( + decode_heavy_hitters, (sketch, count_tensor) + ) + server_output = federated_language.federated_zip( ServerOutput( clients=clients, heavy_hitters=heavy_hitters, diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff_test.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff_test.py index 3d6c17bf12..64aac3bf14 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff_test.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_tff_test.py @@ -19,15 +19,12 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_tff from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils DATA = [ ['hello', 'hey', 'hi', 'hi', 'hi', '新年快乐'], @@ -162,32 +159,34 @@ class IbltTffConstructionTest(absltest.TestCase): def test_default_construction(self): iblt_computation = iblt_tff.build_iblt_computation() - self.assertIsInstance(iblt_computation, computation_base.Computation) - type_test_utils.assert_types_identical( + self.assertIsInstance( + iblt_computation, federated_language.framework.Computation + ) + federated_language.framework.assert_types_identical( iblt_computation.type_signature, - computation_types.FunctionType( - parameter=computation_types.FederatedType( - computation_types.SequenceType( - computation_types.TensorType(shape=[None], dtype=np.str_) + federated_language.FunctionType( + parameter=federated_language.FederatedType( + federated_language.SequenceType( + federated_language.TensorType(shape=[None], dtype=np.str_) ), - placements.CLIENTS, + federated_language.CLIENTS, ), - result=computation_types.FederatedType( + result=federated_language.FederatedType( iblt_tff.ServerOutput( clients=np.int32, - heavy_hitters=computation_types.TensorType( + heavy_hitters=federated_language.TensorType( shape=[None], dtype=np.str_ ), - heavy_hitters_unique_counts=computation_types.TensorType( + heavy_hitters_unique_counts=federated_language.TensorType( shape=[None], dtype=np.int64 ), - heavy_hitters_counts=computation_types.TensorType( + heavy_hitters_counts=federated_language.TensorType( shape=[None], dtype=np.int64 ), num_not_decoded=np.int64, round_timestamp=np.int64, ), - placements.SERVER, + federated_language.SERVER, ), ), ) diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/BUILD b/tensorflow_federated/python/analytics/hierarchical_histogram/BUILD index 36be81313c..dbd67661a2 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/BUILD +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/BUILD @@ -46,10 +46,9 @@ py_test( "//tensorflow_federated/python/aggregators:differential_privacy", "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -60,11 +59,8 @@ py_library( ":clipping_factory", ":hierarchical_histogram_factory", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", + "@federated_language//federated_language", ], ) @@ -75,7 +71,7 @@ py_test( deps = [ ":hierarchical_histogram_lib", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -112,12 +108,8 @@ py_library( "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/aggregators:sum_factory", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", + "@federated_language//federated_language", ], ) @@ -128,9 +120,8 @@ py_test( ":clipping_factory", "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory.py b/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory.py index c08cc7dd9d..fad2dc1683 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory.py +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory.py @@ -15,17 +15,13 @@ from typing import Optional +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process # Supported clip mechanisms. @@ -115,7 +111,7 @@ def create(self, value_type): inner_value_type = value_type if self._cast_to_float: - inner_value_type = computation_types.TensorType( + inner_value_type = federated_language.TensorType( np.float32, value_type.shape ) inner_agg_process = self._inner_agg_factory.create(inner_value_type) @@ -127,22 +123,26 @@ def create(self, value_type): lambda x: tf.cast(x, inner_value_type.dtype) # pytype: disable=attribute-error ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): # Clip values before aggregation. - clipped_value = intrinsics.federated_map( + clipped_value = federated_language.federated_map( tff_clip_fn, ( value, - intrinsics.federated_value( - self._max_records_per_user, placements.CLIENTS + federated_language.federated_value( + self._max_records_per_user, federated_language.CLIENTS ), ), ) - clipped_value = intrinsics.federated_map(tff_cast_fn, clipped_value) + clipped_value = federated_language.federated_map( + tff_cast_fn, clipped_value + ) return inner_agg_process.next(state, clipped_value) @@ -225,7 +225,7 @@ def distinct(): def _check_is_integer_struct(value_type, label): - if not type_analysis.is_structure_of_integers(value_type): + if not federated_language.framework.is_structure_of_integers(value_type): raise TypeError( f'Component dtypes of `{label}` must all be integers. ' f'Found {repr(value_type)}.' @@ -233,7 +233,7 @@ def _check_is_integer_struct(value_type, label): def _check_is_tensor_type(value, label): - if not isinstance(value, computation_types.TensorType): + if not isinstance(value, federated_language.TensorType): raise TypeError( f'Expected `{label}` to be `TensorType`. Found type: {repr(value)}' ) diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory_test.py b/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory_test.py index b79547c595..012bd577f2 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory_test.py +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/clipping_factory_test.py @@ -15,14 +15,13 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.analytics.hierarchical_histogram import clipping_factory from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -40,22 +39,24 @@ def test_clip(self, clip_mechanism): clip_mechanism, 1 ) self.assertIsInstance(clip_factory, factory.UnweightedAggregationFactory) - value_type = computation_types.TensorType(np.int32, shape=(2,)) + value_type = federated_language.TensorType(np.int32, shape=(2,)) process = clip_factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - param_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - value_type, placements.SERVER + result_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER ) - expected_state_type = computation_types.FederatedType((), placements.SERVER) - expected_measurements_type = computation_types.FederatedType( - (), placements.SERVER + expected_state_type = federated_language.FederatedType( + (), federated_language.SERVER + ) + expected_measurements_type = federated_language.FederatedType( + (), federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -64,7 +65,7 @@ def test_clip(self, clip_mechanism): ) ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, value=param_value_type ), @@ -86,7 +87,7 @@ class ClippingSumFactoryExecutionTest(tf.test.TestCase, parameterized.TestCase): def test_raises_value_error( self, clip_mechanism, max_records_per_user, value_type ): - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaises(ValueError): clip_factory = clipping_factory.HistogramClippingSumFactory( clip_mechanism=clip_mechanism, @@ -99,7 +100,7 @@ def test_raises_value_error( ('float_value_type', (np.float32, (2,))), ) def test_raises_type_error(self, value_type): - value_type = computation_types.to_type(value_type) + value_type = federated_language.to_type(value_type) with self.assertRaises(TypeError): clip_factory = clipping_factory.HistogramClippingSumFactory() clip_factory.create(value_type) @@ -155,7 +156,7 @@ def test_sub_sample_clip_factory( clip_factory = clipping_factory.HistogramClippingSumFactory( clip_mechanism='sub-sampling', max_records_per_user=max_records_per_user ) - outer_value_type = computation_types.TensorType( + outer_value_type = federated_language.TensorType( np.int32, shape=(value_shape,) ) process = clip_factory.create(outer_value_type) @@ -189,7 +190,7 @@ def test_distinct_clip_factory( clip_mechanism='distinct', max_records_per_user=max_records_per_user ) - value_type = computation_types.TensorType(np.int32, shape=(value_shape,)) + value_type = federated_language.TensorType(np.int32, shape=(value_shape,)) process = clip_factory.create(value_type) state = process.initialize() diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_factory_test.py b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_factory_test.py index 5296583cec..3b87062e81 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_factory_test.py +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_factory_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -23,8 +24,6 @@ from tensorflow_federated.python.analytics.hierarchical_histogram import build_tree_from_leaf from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory as hihi_factory from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process @@ -52,27 +51,27 @@ def test_no_noise_tree_aggregation( ) ) self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - query_state_type = computation_types.StructType([ - ('arity', computation_types.TensorType(np.int32)), - ('inner_query_state', computation_types.StructType([])), + query_state_type = federated_language.StructType([ + ('arity', federated_language.TensorType(np.int32)), + ('inner_query_state', federated_language.StructType([])), ]) - query_metrics_type = computation_types.StructType([]) + query_metrics_type = federated_language.StructType([]) - dp_event_type = computation_types.StructType([ - ('module_name', computation_types.TensorType(np.str_)), - ('class_name', computation_types.TensorType(np.str_)), + dp_event_type = federated_language.StructType([ + ('module_name', federated_language.TensorType(np.str_)), + ('class_name', federated_language.TensorType(np.str_)), ]) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( differential_privacy.DPAggregatorState( query_state_type, (), dp_event_type, np.bool_ ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -90,36 +89,36 @@ def test_no_noise_tree_aggregation( ) else: expected_measurements_dp = () - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( dp_query_metrics=query_metrics_type, dp=expected_measurements_dp ), - placements.SERVER, + federated_language.SERVER, ) tree_depth = hihi_factory._tree_depth(value_shape, arity) flat_tree_shape = (arity**tree_depth - 1) // (arity - 1) - result_value_type = computation_types.to_type( + result_value_type = federated_language.to_type( collections.OrderedDict([ ( 'flat_values', - computation_types.to_type((np.int32, (flat_tree_shape,))), + federated_language.to_type((np.int32, (flat_tree_shape,))), ), ('nested_row_splits', [(np.int64, (tree_depth + 1,))]), ]) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) - expected_next_type = computation_types.FunctionType( + value_type = federated_language.to_type((np.int32, (value_shape,))) + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - result_value_type, placements.SERVER + result=federated_language.FederatedType( + result_value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -147,35 +146,35 @@ def test_central_gaussian_tree_aggregation( ) ) self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - query_state_type = computation_types.StructType([ - ('arity', computation_types.TensorType(np.int32)), + query_state_type = federated_language.StructType([ + ('arity', federated_language.TensorType(np.int32)), ( 'inner_query_state', - computation_types.StructType([ - ('l2_norm_clip', computation_types.TensorType(np.float32)), - ('stddev', computation_types.TensorType(np.float32)), + federated_language.StructType([ + ('l2_norm_clip', federated_language.TensorType(np.float32)), + ('stddev', federated_language.TensorType(np.float32)), ]), ), ]) - query_metrics_type = computation_types.StructType([]) + query_metrics_type = federated_language.StructType([]) # template_type is not derived from value_type in this test because the # outer factory converts the ints to floats before they reach the query. - dp_event_type = computation_types.StructType([ - ('module_name', computation_types.TensorType(np.str_)), - ('class_name', computation_types.TensorType(np.str_)), + dp_event_type = federated_language.StructType([ + ('module_name', federated_language.TensorType(np.str_)), + ('class_name', federated_language.TensorType(np.str_)), ]) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( differential_privacy.DPAggregatorState( query_state_type, (), dp_event_type, np.bool_ ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -193,35 +192,35 @@ def test_central_gaussian_tree_aggregation( ) else: expected_measurements_dp = () - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( dp_query_metrics=query_metrics_type, dp=expected_measurements_dp ), - placements.SERVER, + federated_language.SERVER, ) tree_depth = hihi_factory._tree_depth(value_shape, arity) flat_tree_shape = (arity**tree_depth - 1) // (arity - 1) - result_value_type = computation_types.to_type( + result_value_type = federated_language.to_type( collections.OrderedDict([ ( 'flat_values', - computation_types.to_type((np.float32, (flat_tree_shape,))), + federated_language.to_type((np.float32, (flat_tree_shape,))), ), ('nested_row_splits', [(np.int64, (tree_depth + 1,))]), ]) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) - expected_next_type = computation_types.FunctionType( + value_type = federated_language.to_type((np.int32, (value_shape,))) + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - result_value_type, placements.SERVER + result=federated_language.FederatedType( + result_value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -249,33 +248,33 @@ def test_distributed_discrete_gaussian_tree_aggregation( ) ) self.assertIsInstance(agg_factory, factory.UnweightedAggregationFactory) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - query_state_type = computation_types.StructType([ - ('arity', computation_types.TensorType(np.int32)), + query_state_type = federated_language.StructType([ + ('arity', federated_language.TensorType(np.int32)), ( 'inner_query_state', - computation_types.StructType([ - ('l2_norm_bound', computation_types.TensorType(np.float32)), - ('local_stddev', computation_types.TensorType(np.float32)), + federated_language.StructType([ + ('l2_norm_bound', federated_language.TensorType(np.float32)), + ('local_stddev', federated_language.TensorType(np.float32)), ]), ), ]) - query_metrics_type = computation_types.StructType([]) + query_metrics_type = federated_language.StructType([]) - dp_event_type = computation_types.StructType([ - ('module_name', computation_types.TensorType(np.str_)), - ('class_name', computation_types.TensorType(np.str_)), + dp_event_type = federated_language.StructType([ + ('module_name', federated_language.TensorType(np.str_)), + ('class_name', federated_language.TensorType(np.str_)), ]) - server_state_type = computation_types.FederatedType( + server_state_type = federated_language.FederatedType( differential_privacy.DPAggregatorState( query_state_type, (), dp_event_type, np.bool_ ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=server_state_type ) self.assertTrue( @@ -283,34 +282,34 @@ def test_distributed_discrete_gaussian_tree_aggregation( expected_initialize_type ) ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict(dp_query_metrics=query_metrics_type, dp=()), - placements.SERVER, + federated_language.SERVER, ) tree_depth = hihi_factory._tree_depth(value_shape, arity) flat_tree_shape = (arity**tree_depth - 1) // (arity - 1) - result_value_type = computation_types.to_type( + result_value_type = federated_language.to_type( collections.OrderedDict([ ( 'flat_values', - computation_types.to_type((np.int32, (flat_tree_shape,))), + federated_language.to_type((np.int32, (flat_tree_shape,))), ), ('nested_row_splits', [(np.int64, (tree_depth + 1,))]), ]) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) - expected_next_type = computation_types.FunctionType( + value_type = federated_language.to_type((np.int32, (value_shape,))) + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=server_state_type, - value=computation_types.FederatedType( - value_type, placements.CLIENTS + value=federated_language.FederatedType( + value_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( state=server_state_type, - result=computation_types.FederatedType( - result_value_type, placements.SERVER + result=federated_language.FederatedType( + result_value_type, federated_language.SERVER ), measurements=expected_measurements_type, ), @@ -506,7 +505,7 @@ def test_no_noise_tree_aggregation_wo_clip( enable_secure_sum=enable_secure_sum, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() @@ -561,7 +560,7 @@ def test_no_noise_tree_aggregation_w_clip( enable_secure_sum=enable_secure_sum, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() @@ -621,7 +620,7 @@ def test_central_gaussian_tree_aggregation_wo_clip( enable_secure_sum=enable_secure_sum, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() @@ -684,7 +683,7 @@ def test_central_gaussian_tree_aggregation_w_clip( enable_secure_sum=enable_secure_sum, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() @@ -744,7 +743,7 @@ def test_distributed_discrete_gaussian_tree_aggregation_wo_clip( enable_secure_sum=True, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() @@ -806,7 +805,7 @@ def test_distributed_discrete_gaussian_tree_aggregation_w_clip( enable_secure_sum=True, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() @@ -865,7 +864,7 @@ def test_distributed_discrete_gaussian_tree_aggregation_no_overflow( enable_secure_sum=True, ) ) - value_type = computation_types.to_type((np.int32, (value_shape,))) + value_type = federated_language.to_type((np.int32, (value_shape,))) process = agg_factory.create(value_type) state = process.initialize() diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib.py b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib.py index d8f8b25f20..af67c0ee89 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib.py +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib.py @@ -17,16 +17,13 @@ from typing import Any import attrs +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.analytics.hierarchical_histogram import clipping_factory from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_factory from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process @@ -210,7 +207,7 @@ def build_hierarchical_histogram_computation( ) @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.float32) + federated_language.SequenceType(np.float32) ) def client_work(client_data): return _discretized_histogram_counts( @@ -231,23 +228,23 @@ def client_work(client_data): process = agg_factory.create(client_work.type_signature.result) - @federated_computation.federated_computation( - computation_types.FederatedType( - client_work.type_signature.parameter, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + client_work.type_signature.parameter, federated_language.CLIENTS ) ) def hierarchical_histogram_computation(federated_client_data): - round_timestamp = intrinsics.federated_eval( + round_timestamp = federated_language.federated_eval( tensorflow_computation.tf_computation( lambda: tf.cast(tf.timestamp(), np.int64) ), - placements.SERVER, + federated_language.SERVER, ) - client_histogram = intrinsics.federated_map( + client_histogram = federated_language.federated_map( client_work, federated_client_data ) - server_output = intrinsics.federated_zip( + server_output = federated_language.federated_zip( ServerOutput( process.next(process.initialize(), client_histogram).result, round_timestamp, @@ -393,17 +390,19 @@ def initialize(): } return ServerOutput(initial_hierarchical_histogram, initial_timestamp) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_eval(initialize, placements.SERVER) + return federated_language.federated_eval( + initialize, federated_language.SERVER + ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, parameter_type_signature ) def next_fn(_, client_data): - return one_round_computation(client_data), intrinsics.federated_value( - (), placements.SERVER - ) + return one_round_computation( + client_data + ), federated_language.federated_value((), federated_language.SERVER) return iterative_process.IterativeProcess(init_fn, next_fn) diff --git a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py index 01f4d30dda..677b5eeb1c 100644 --- a/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py +++ b/tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_lib_test.py @@ -15,12 +15,12 @@ from unittest import mock from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.analytics.hierarchical_histogram import hierarchical_histogram_lib as hihi from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.test import static_assert MOCK_TIME_SECONDS = 314159.2653 EXPECTED_ROUND_TIMESTAMP = 314159 @@ -1064,7 +1064,9 @@ def test_secure_sum(self, dp_mechanism): dp_mechanism=dp_mechanism, enable_secure_sum=True, ) - static_assert.assert_not_contains_unsecure_aggregation(hihi_computation) + federated_language.framework.assert_not_contains_unsecure_aggregation( + hihi_computation + ) @mock.patch('tensorflow.timestamp') def test_round_timestamp(self, timestamp_mock): diff --git a/tensorflow_federated/python/common_libs/BUILD b/tensorflow_federated/python/common_libs/BUILD index 06bdd7ae09..a1ff7d0a3c 100644 --- a/tensorflow_federated/python/common_libs/BUILD +++ b/tensorflow_federated/python/common_libs/BUILD @@ -27,7 +27,7 @@ py_library( py_library( name = "async_utils", srcs = ["async_utils.py"], - deps = [":tracing"], + deps = ["@federated_language//federated_language"], ) py_library( @@ -79,48 +79,8 @@ py_test( deps = [":py_typecheck"], ) -py_library( - name = "retrying", - srcs = ["retrying.py"], - deps = [ - ":py_typecheck", - ], -) - -py_test( - name = "retrying_test", - size = "small", - srcs = ["retrying_test.py"], - deps = [":retrying"], -) - -py_library( - name = "serializable", - srcs = ["serializable.py"], -) - py_library( name = "structure", srcs = ["structure.py"], - deps = [":py_typecheck"], -) - -py_test( - name = "structure_test", - size = "small", - srcs = ["structure_test.py"], - deps = [":structure"], -) - -py_library( - name = "tracing", - srcs = ["tracing.py"], - deps = [":py_typecheck"], -) - -py_test( - name = "tracing_test", - size = "small", - srcs = ["tracing_test.py"], - deps = [":tracing"], + deps = ["@federated_language//federated_language/common_libs:structure"], ) diff --git a/tensorflow_federated/python/common_libs/async_utils.py b/tensorflow_federated/python/common_libs/async_utils.py index c1632b04e8..2377359bdd 100644 --- a/tensorflow_federated/python/common_libs/async_utils.py +++ b/tensorflow_federated/python/common_libs/async_utils.py @@ -16,7 +16,7 @@ import asyncio import threading -from tensorflow_federated.python.common_libs import tracing +import federated_language class AsyncThreadRunner: @@ -42,7 +42,7 @@ class AsyncThreadRunner: def __init__(self): self._event_loop = asyncio.new_event_loop() self._event_loop.set_task_factory( - tracing.propagate_trace_context_task_factory + federated_language.framework.propagate_trace_context_task_factory ) def target_fn(): diff --git a/tensorflow_federated/python/common_libs/retrying.py b/tensorflow_federated/python/common_libs/retrying.py deleted file mode 100644 index fb5add03bd..0000000000 --- a/tensorflow_federated/python/common_libs/retrying.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Library of pure-python retrying decorators.""" - -import asyncio -from collections.abc import Callable -import functools -import inspect -import time -from typing import Union - -from tensorflow_federated.python.common_libs import py_typecheck - - -def retry( - fn=None, - *, - retry_on_exception_filter: Callable[[Exception], bool] = lambda x: True, - retry_on_result_filter: Callable[[object], bool] = lambda x: False, - wait_max_ms: Union[float, int] = 30000, - wait_multiplier: Union[float, int] = 2, -): - """Pure Python decorator that retries functions or coroutine functions. - - `retry` starts at some delay between function invocations, and backs - off exponentialy with factor `wait_multiplier` until the max of - `max_wait_ms`, at which point `retry` will continue to retry `fn` at intervals - of `max_wait_ms` until `retry_on_exception_filter` returns `False`. - - Args: - fn: Optional Python function or coroutine function to wrap in retrying - logic. If None, `retry` will return a callable which decorates a function - or corofunc to be passed later. - retry_on_exception_filter: Function accepting a Python `Exception`, and - returning a Boolean indicating whether or not to retry the invocation. - retry_on_result_filter: Function accepting a function result or coroutine - function result, and returning a Boolean indicating whether or not to - retry the invocation. - wait_max_ms: Maximum time `retry` is allowed to wait between invocations of - `fn`, in milliseconds. Must be positive. - wait_multiplier: Number determining the exponential backoff multiplier to - use. Must be positive. - - Returns: - In the case that `fn` is provided, a decorated version of `fn` respecting - the semantics above. If `fn` is not provided, returns a callable which can - be used to decorate a function or coroutine function at a later time. - """ - py_typecheck.check_type(wait_max_ms, (float, int)) - py_typecheck.check_type(wait_multiplier, (float, int)) - if not inspect.isfunction(retry_on_exception_filter): - raise TypeError( - 'Expected function to be passed as retry_on_exception_filter; ' - 'encountered {} of type {}.'.format( - retry_on_exception_filter, type(retry_on_exception_filter) - ) - ) - if not inspect.isfunction(retry_on_result_filter): - raise TypeError( - 'Expected function to be passed as retry_on_result_filter; ' - 'encountered {} of type {}.'.format( - retry_on_result_filter, type(retry_on_result_filter) - ) - ) - if wait_max_ms <= 0: - raise ValueError( - 'wait_max_ms required to be positive; encountered value {}.'.format( - wait_max_ms - ) - ) - if wait_multiplier <= 0: - raise ValueError( - 'wait_multiplier required to be positive; encountered value {}.'.format( - wait_multiplier - ) - ) - - if fn is None: - # Called with arguments; delay decoration until `fn` is passed in. - return functools.partial( - retry, - retry_on_exception_filter=retry_on_exception_filter, - retry_on_result_filter=retry_on_result_filter, - wait_max_ms=wait_max_ms, - wait_multiplier=wait_multiplier, - ) - - if inspect.iscoroutinefunction(fn): - # Similar to the logic in tracing.py, we case on corofunction versus vanilla - # function. - - @functools.wraps(fn) - async def retry_coro_fn(*args, **kwargs): - retry_wait_ms = 1.0 - - while True: - try: - result = await fn(*args, **kwargs) - if retry_on_result_filter(result): - retry_wait_ms = min(wait_max_ms, retry_wait_ms * wait_multiplier) - # time.sleep takes arguments in seconds. - await asyncio.sleep(retry_wait_ms / 1000) - continue - else: - return result - except Exception as e: # pylint: disable=broad-except - if not retry_on_exception_filter(e): - raise e - retry_wait_ms = min(wait_max_ms, retry_wait_ms * wait_multiplier) - # asyncio.sleep takes arguments in seconds. - await asyncio.sleep(retry_wait_ms / 1000) - - return retry_coro_fn - - elif inspect.isfunction(fn): - # Vanilla Python function; decorate as normal. - - @functools.wraps(fn) - def retry_fn(*args, **kwargs): - retry_wait_ms = 1.0 - - while True: - try: - result = fn(*args, **kwargs) - if retry_on_result_filter(result): - retry_wait_ms = min(wait_max_ms, retry_wait_ms * wait_multiplier) - # time.sleep takes arguments in seconds. - time.sleep(retry_wait_ms / 1000) - continue - else: - return result - except Exception as e: # pylint: disable=broad-except - if not retry_on_exception_filter(e): - raise e - retry_wait_ms = min(wait_max_ms, retry_wait_ms * wait_multiplier) - # time.sleep takes arguments in seconds. - time.sleep(retry_wait_ms / 1000) - - return retry_fn - - else: - raise TypeError( - 'Retrying expects Python function or coroutine function; ' - 'passed {} of type {}.'.format(fn, type(fn)) - ) diff --git a/tensorflow_federated/python/common_libs/retrying_test.py b/tensorflow_federated/python/common_libs/retrying_test.py deleted file mode 100644 index 98671f10ff..0000000000 --- a/tensorflow_federated/python/common_libs/retrying_test.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from unittest import mock - -from absl.testing import absltest - -from tensorflow_federated.python.common_libs import retrying - - -class RetryingArgValidationtest(absltest.TestCase): - - def test_raises_non_function(self): - with self.assertRaises(TypeError): - retrying.retry(fn=0) - - def test_raises_non_function_exception_filter(self): - with self.assertRaises(TypeError): - retrying.retry(fn=lambda x: x, retry_on_exception_filter=0) - - def test_raises_non_function_result_filter(self): - with self.assertRaises(TypeError): - retrying.retry(fn=lambda x: x, retry_on_result_filter=0) - - def test_raises_complex_wait_multiplier(self): - with self.assertRaises(TypeError): - retrying.retry(fn=lambda x: x, wait_multiplier=1j) - - def test_raises_complex_max_wait_ms(self): - with self.assertRaises(TypeError): - retrying.retry(fn=lambda x: x, wait_max_ms=1j) - - def test_raises_zero_wait_multiplier(self): - with self.assertRaises(ValueError): - retrying.retry(fn=lambda x: x, wait_multiplier=0) - - def test_raises_zero_max_wait_ms(self): - with self.assertRaises(ValueError): - retrying.retry(fn=lambda x: x, wait_max_ms=0) - - -class CountInvocations: - - def __init__( - self, - n_invocations_to_raise: int, - error_to_raise: Exception, - return_value: object, - ): - self._n_invocations_to_raise = n_invocations_to_raise - self._error_to_raise = error_to_raise - self._return_value = return_value - self._n_invocations = 0 - - @property - def n_invocations(self): - return self._n_invocations - - def __call__(self, *args, **kwargs): - del args, kwargs # Unused - self._n_invocations += 1 - if self._n_invocations <= self._n_invocations_to_raise: - raise self._error_to_raise - return self._return_value - - -class RetryingFunctionTest(absltest.TestCase): - - def test_standalone_decorator_always_retries(self): - expected_return_val = 0 - expected_num_invocations = 3 - count_invocations_callable = CountInvocations( - expected_num_invocations, TypeError('Error'), expected_return_val - ) - - @retrying.retry - def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - return_val = invoke_callable() - - self.assertEqual(return_val, expected_return_val) - # Final call succeeds - self.assertEqual( - count_invocations_callable.n_invocations, expected_num_invocations + 1 - ) - - def test_error_filter_raises_wrong_error_type(self): - count_invocations_callable = CountInvocations(1, TypeError('Error'), 0) - - @retrying.retry( - retry_on_exception_filter=lambda e: isinstance(e, ValueError) - ) - def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - with self.assertRaises(TypeError): - invoke_callable() - - def test_error_filter_called_with_raised_err(self): - error = TypeError('error') - expected_result = 1 - - count_invocations_callable = CountInvocations(1, error, 1) - mock_callable = mock.MagicMock(return_value=True) - - def err_filter(*args): - return mock_callable(*args) - - @retrying.retry(retry_on_exception_filter=err_filter) - def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - result = invoke_callable() - self.assertEqual(result, expected_result) - mock_callable.assert_called_once_with(error) - - def test_result_filter_not_incur_retry(self): - expected_return_val = 0 - expected_num_invocations = 3 - count_invocations_callable = CountInvocations( - expected_num_invocations, TypeError('Error'), expected_return_val - ) - mock_callable = mock.MagicMock(return_value=False) - - def result_filter(*args): - return mock_callable(*args) - - @retrying.retry(retry_on_result_filter=result_filter) - def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - return_val = invoke_callable() - - self.assertEqual(return_val, expected_return_val) - # Final call succeeds - self.assertEqual( - count_invocations_callable.n_invocations, expected_num_invocations + 1 - ) - - def test_result_filter_incur_retry(self): - expected_return_val = 0 - expected_num_invocations = 3 - count_invocations_callable = CountInvocations( - expected_num_invocations, TypeError('Error'), expected_return_val - ) - mock_callable = mock.Mock() - mock_callable.side_effect = [True, False] - - def result_filter(*args): - return mock_callable(*args) - - @retrying.retry(retry_on_result_filter=result_filter) - def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - return_val = invoke_callable() - - self.assertEqual(return_val, expected_return_val) - # Final call succeeds - self.assertEqual( - count_invocations_callable.n_invocations, expected_num_invocations + 2 - ) - - -class RetryingCoroFunctionTest(absltest.TestCase): - - def setUp(self): - self._loop = asyncio.new_event_loop() - super().setUp() - - def _run_sync(self, fn, args=None): - return self._loop.run_until_complete(fn(args)) - - def test_standalone_decorator_always_retries(self): - expected_return_val = 0 - expected_num_invocations = 3 - count_invocations_callable = CountInvocations( - expected_num_invocations, TypeError('Error'), expected_return_val - ) - - @retrying.retry - async def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - return_val = self._run_sync(invoke_callable) - - self.assertEqual(return_val, expected_return_val) - # Final call succeeds - self.assertEqual( - count_invocations_callable.n_invocations, expected_num_invocations + 1 - ) - - def test_error_filter_raises_wrong_error_type(self): - count_invocations_callable = CountInvocations(1, TypeError('Error'), 0) - - @retrying.retry( - retry_on_exception_filter=lambda e: isinstance(e, ValueError) - ) - async def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - with self.assertRaises(TypeError): - self._run_sync(invoke_callable) - - def test_error_filter_called_with_raised_err(self): - error = TypeError('error') - expected_result = 1 - - count_invocations_callable = CountInvocations(1, error, 1) - mock_callable = mock.MagicMock(return_value=True) - - def err_filter(*args): - return mock_callable(*args) - - @retrying.retry(retry_on_exception_filter=err_filter) - async def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - result = self._run_sync(invoke_callable) - self.assertEqual(result, expected_result) - mock_callable.assert_called_once_with(error) - - def test_result_filter_not_incur_retry(self): - expected_result = 0 - expected_num_invocations = 3 - count_invocations_callable = CountInvocations( - expected_num_invocations, TypeError('Error'), expected_result - ) - mock_callable = mock.MagicMock(return_value=False) - - def result_filter(*args): - return mock_callable(*args) - - @retrying.retry(retry_on_result_filter=result_filter) - async def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - result = self._run_sync(invoke_callable) - - self.assertEqual(result, expected_result) - # Final call succeeds - self.assertEqual( - count_invocations_callable.n_invocations, expected_num_invocations + 1 - ) - - def test_result_filter_incur_retry(self): - expected_result = 0 - expected_num_invocations = 3 - count_invocations_callable = CountInvocations( - expected_num_invocations, TypeError('Error'), expected_result - ) - mock_callable = mock.Mock() - mock_callable.side_effect = [True, False] - - def result_filter(*args): - return mock_callable(*args) - - @retrying.retry(retry_on_result_filter=result_filter) - async def invoke_callable(*args, **kwargs): - return count_invocations_callable(*args, **kwargs) - - result = self._run_sync(invoke_callable) - - self.assertEqual(result, expected_result) - # Final call succeeds - self.assertEqual( - count_invocations_callable.n_invocations, expected_num_invocations + 2 - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/common_libs/serializable.py b/tensorflow_federated/python/common_libs/serializable.py deleted file mode 100644 index fe89caa701..0000000000 --- a/tensorflow_federated/python/common_libs/serializable.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2023, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines an abstract interface for objects that can be serialized.""" - -import abc - - -class Serializable(abc.ABC): - - @classmethod - @abc.abstractmethod - def from_bytes(cls, buffer: bytes) -> 'Serializable': - """Deserializes the object from bytes.""" - raise NotImplementedError - - @abc.abstractmethod - def to_bytes(self) -> bytes: - """Serializes the object to bytes.""" - raise NotImplementedError diff --git a/tensorflow_federated/python/common_libs/structure.py b/tensorflow_federated/python/common_libs/structure.py index b0eb26ff8f..faa1a05597 100644 --- a/tensorflow_federated/python/common_libs/structure.py +++ b/tensorflow_federated/python/common_libs/structure.py @@ -13,687 +13,20 @@ # limitations under the License. """Container for structures with named and/or unnamed fields.""" -import collections -from collections.abc import Callable, Iterable, Iterator, Mapping -import typing -from typing import Generic, Optional, TypeVar, Union - -import attrs -import tree - -from tensorflow_federated.python.common_libs import py_typecheck - - -_T = TypeVar('_T') -_U = TypeVar('_U') - - -class Struct(Generic[_T]): - """Represents a struct-like structure with named and/or unnamed fields. - - `Struct`s are similar to `collections.namedtuple` in that their elements can - be accessed by name or by index. However, `Struct`s provide a performance - improvement over `collections.namedtuple` by using a single class to - represent values with many different possible structures, rather than - creating a brand new class for every new instance. - - `Struct`s are commonly used inside Tensorflow Federated as a standard - intermediate representation of other structure types, including `list`s, - `tuple`s, `dict`s, `namedtuple`s, and `attr.s` classes. - - Example: - - ```python - x = Struct([('foo', 10), (None, 20), ('bar', 30)]) - - len(x) == 3 - x[0] == 10 - x[1] == 20 - x[2] == 30 - list(iter(x)) == [10, 20, 30] - dir(x) == ['bar', 'foo'] - x.foo == 10 - x['bar'] == 30 - ``` - - Note that field names are optional, allowing `Struct` to be used like an - ordinary positional tuple. - """ - - __slots__ = ( - '_hash', - '_element_array', - '_name_to_index', - '_name_array', - '_elements_cache', - ) - - @classmethod - def named(cls, **kwargs: _T) -> 'Struct': - """Constructs a new `Struct` with all named elements.""" - return cls(tuple(kwargs.items())) - - @classmethod - def unnamed(cls, *args: _T) -> 'Struct': - """Constructs a new `Struct` with all unnamed elements.""" - return cls(tuple((None, v) for v in args)) - - def __init__(self, elements: Iterable[tuple[Optional[str], _T]]): - """Constructs a new `Struct` with the given elements. - - Args: - elements: An iterable of element specifications, each being a pair - consisting of the element name (either `str`, or `None`), and the - element value. The order is significant. - - Raises: - TypeError: if the `elements` are not a list, or if any of the items on - the list is not a pair with a string at the first position. - """ - values = [] - names = [] - name_to_index = {} - reserved_names = frozenset(('_asdict',) + Struct.__slots__) - for idx, e in enumerate(elements): - if not py_typecheck.is_name_value_pair(e): - raise TypeError( - 'Expected every item on the list to be a pair in which the first ' - 'element is a string, found {!r}.'.format(e) - ) - name, value = e - if name in reserved_names: - raise ValueError( - 'The names in {} are reserved. You passed the name {}.'.format( - reserved_names, name - ) - ) - elif name in name_to_index: - raise ValueError( - '`Struct` does not support duplicated names, found {}.'.format( - [e[0] for e in elements] - ) - ) - names.append(name) - values.append(value) - if name is not None: - name_to_index[name] = idx - self._element_array = tuple(values) - self._name_to_index = name_to_index - self._name_array = names - self._hash = None - self._elements_cache = None - - def _elements(self) -> list[tuple[Optional[str], _T]]: - if self._elements_cache is None: - self._elements_cache = list(zip(self._name_array, self._element_array)) - return self._elements_cache - - def __len__(self) -> int: - return len(self._element_array) - - def __iter__(self) -> Iterator[_T]: - return iter(self._element_array) - - def __dir__(self) -> list[str]: - """The list of names. - - IMPORTANT: `len(self)` may be greater than `len(dir(self))`, since field - names are not required by `Struct`. - - IMPORTANT: the Python `dir()` built-in sorts the list returned by this - method. - - Returns: - A `list` of `str`. - """ - return list(self._name_to_index.keys()) - - def __getitem__(self, key: Union[int, str, slice]) -> _T: - if isinstance(key, str): - return self.__getattr__(key) - if isinstance(key, int): - if key < 0 or key >= len(self._element_array): - raise IndexError( - 'Element index {} is out of range, `Struct` has {} elements.' - .format(key, len(self._element_array)) - ) - return self._element_array[key] - - def __getattr__(self, name: str) -> _T: - if name not in self._name_to_index: - raise AttributeError( - 'The `Struct` of length {:d} does not have named field "{!s}". ' - 'Fields (up to first 10): {!s}'.format( - len(self._element_array), - name, - list(self._name_to_index.keys())[:10], - ) - ) - return self._element_array[self._name_to_index[name]] - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Struct): - return NotImplemented - return ( - self._element_array, - self._name_array, - ) == ( - other._element_array, - other._name_array, - ) - - def __ne__(self, other: object) -> bool: - return not self == other - - def __repr__(self) -> str: - return 'Struct([{}])'.format( - ', '.join('({!r}, {!r})'.format(n, v) for n, v in iter_elements(self)) - ) - - def __str__(self) -> str: - def _element_str(element: tuple[Optional[str], object]) -> str: - name, value = element - if name is not None: - return '{}={}'.format(name, value) - return str(value) - - return '<{}>'.format(','.join(_element_str(e) for e in iter_elements(self))) - - def __hash__(self) -> int: - if self._hash is None: - self._hash = hash(( - 'Struct', # salting to avoid type mismatch. - self._element_array, - tuple(self._name_array), - )) - return self._hash - - def _asdict( - self, recursive: bool = False - ) -> collections.OrderedDict[str, _T]: - """Returns an `collections.OrderedDict` mapping field names to their values. - - Args: - recursive: Whether to convert nested `Struct`s recursively. - """ - return to_odict(self, recursive=recursive) - - -def name_list(struct: Struct) -> list[str]: - """Returns a `list` of the names of the named fields in `struct`. - - Args: - struct: An instance of `Struct`. - - Returns: - The list of string names for the fields that are named. Names appear in - order, skipping names that are `None`. - """ - names = struct._name_array # pylint: disable=protected-access - return [n for n in names if n is not None] - - -def name_list_with_nones(struct: Struct) -> list[Optional[str]]: - """Returns an iterator over the names of all fields in `struct`.""" - return struct._name_array # pylint: disable=protected-access - - -def to_elements(struct: Struct[_T]) -> list[tuple[Optional[str], _T]]: - """Retrieves the list of (name, value) pairs from a `Struct`. - - Modeled as a module function rather than a method of `Struct` to avoid - naming conflicts with the tuple attributes, and so as not to expose the user - to this implementation-oriented functionality. - - Args: - struct: An instance of `Struct`. - - Returns: - The list of (name, value) pairs in which names can be None. Identical to - the format that's accepted by the tuple constructor. - - Raises: - TypeError: if the argument is not an `Struct`. - """ - return struct._elements().copy() # pylint: disable=protected-access - - -def iter_elements(struct: Struct[_T]) -> Iterator[tuple[Optional[str], _T]]: - """Returns an iterator over (name, value) pairs from a `Struct`. - - Modeled as a module function rather than a method of `Struct` to avoid - naming conflicts with the tuple attributes, and so as not to expose the user - to this implementation-oriented functionality. - - Args: - struct: An instance of `Struct`. - - Returns: - An iterator of 2-tuples of name, value pairs, representing the elements of - `struct`. - - Raises: - TypeError: if the argument is not an `Struct`. - """ - return iter(struct._elements()) # pylint: disable=protected-access - - -def to_odict( - struct: Struct[_T], recursive: bool = False -) -> collections.OrderedDict[str, _T]: - """Returns `struct` as an `collections.OrderedDict`, if possible. - - Args: - struct: An `Struct`. - recursive: Whether to convert nested `Struct`s recursively. - - Raises: - ValueError: If the `Struct` contains unnamed elements. - """ - - def _to_odict( - elements: list[tuple[Optional[str], _T]] - ) -> collections.OrderedDict[str, _T]: - for name, _ in elements: - if name is None: - raise ValueError( - 'Cannot convert an `Struct` with unnamed entries to a ' - '`collections.OrderedDict`: {}'.format(struct) - ) - elements = typing.cast(list[tuple[str, _T]], elements) - return collections.OrderedDict(elements) - - if recursive: - return _to_container_recursive(struct, _to_odict) - else: - return _to_odict(to_elements(struct)) - - -def to_odict_or_tuple( - struct: Struct[_T], recursive: bool = True -) -> Union[collections.OrderedDict[str, _T], tuple[_T, ...]]: - """Returns `struct` as an `collections.OrderedDict` or `tuple`, if possible. - - If all elements of `struct` have names, convert `struct` to an - `collections.OrderedDict`. If no element has a name, convert `struct` to a - `tuple`. If - `struct` has both named and unnamed elements, raise an error. - - Args: - struct: A `Struct`. - recursive: Whether to convert nested `Struct`s recursively. - - Raises: - ValueError: If `struct` (or any nested `Struct` when `recursive=True`) - contains both named and unnamed elements. - """ - - def _to_odict_or_tuple( - elements: list[tuple[Optional[str], _T]] - ) -> Union[collections.OrderedDict[str, _T], tuple[_T, ...]]: - fields_are_named = tuple(name is not None for name, _ in elements) - if any(fields_are_named): - if not all(fields_are_named): - raise ValueError( - 'Cannot convert a `Struct` with both named and unnamed ' - 'entries to an collections.OrderedDict or tuple: {!r}'.format( - struct - ) - ) - elements = typing.cast(list[tuple[str, _T]], elements) - return collections.OrderedDict(elements) - else: - return tuple(value for _, value in elements) - - if recursive: - return _to_container_recursive(struct, _to_odict_or_tuple) - else: - return _to_odict_or_tuple(to_elements(struct)) - - -def flatten(struct: object) -> list[object]: - """Returns a list of values in a possibly recursively nested `Struct`. - - Note: _This implementation is not compatible with the approach of - `tf.nest.flatten`, which enforces lexical order for - `collections.OrderedDict`s. - - Args: - struct: A `Struct`, possibly recursively nested, or a non-`Struct` element - that can be packed with `tf.nest.flatten`. If `struct` has - non-`Struct`-typed fields which should be flattened further, they should - not contain inner `Structs`, as these will not be flattened (e.g. - `Struct([('a', collections.OrderedDict(b=Struct([('c', 5)])))])` would not - be valid). - - Returns: - The list of leaf values in the `Struct`. - """ - if not isinstance(struct, Struct): - return tree.flatten(struct) - else: - result = [] - for _, v in iter_elements(struct): - result.extend(flatten(v)) - return result - - -def pack_sequence_as( - structure: Struct[_T], flat_sequence: list[_T] -) -> Struct[_T]: - """Returns a list of values in a possibly recursively nested `Struct`. - - Args: - structure: A `Struct`, possibly recursively nested. - flat_sequence: A flat Python list of values. - - Returns: - A `Struct` nested the same way as `structure`, but with leaves - replaced with `flat_sequence` such that when flatten, it yields a list - with the same contents as `flat_sequence`. - """ - - def _pack( - structure, flat_sequence: list[_T], position: int - ) -> tuple[Struct[_T], int]: - """Pack a leaf element or recurvisely iterate over an `Struct`.""" - if not isinstance(structure, Struct): - # Ensure that our leaf values are not structures. - if isinstance( - structure, (list, dict, py_typecheck.SupportsNamedTuple) - ) or attrs.has(type(structure)): - raise TypeError( - 'Cannot pack sequence into type {!s}, only structures of ' - '`Struct` are supported, found a structure with types ' - '{!s}).'.format(type(structure), structure) - ) - - return flat_sequence[position], position + 1 - else: - elements = [] - for k, v in iter_elements(structure): - packed_v, position = _pack(v, flat_sequence, position) - elements.append((k, packed_v)) - return Struct(elements), position - - result, _ = _pack(structure, flat_sequence, 0) - # Note: trailing elements are currently ignored. - return result - - -def is_same_structure(a: Struct, b: Struct) -> bool: - """Compares whether `a` and `b` have the same nested structure. - - This method is analogous to `tf.nest.assert_same_structure`, - but returns a boolean rather than throwing an exception. - - Args: - a: a `Struct` object. - b: a `Struct` object. - - Returns: - True iff `a` and `b` have the same nested structure. - - Raises: - TypeError: if `a` or `b` are not of type `Struct`. - """ - elems_a = to_elements(a) - elems_b = to_elements(b) - if len(elems_a) != len(elems_b): - return False - for elem_a, elem_b in zip(elems_a, elems_b): - val_a = elem_a[1] - val_b = elem_b[1] - if elem_a[0] != elem_b[0]: - return False - if isinstance(val_a, Struct) and isinstance(val_b, Struct): - return is_same_structure(val_a, val_b) - elif isinstance(val_a, Struct) or isinstance(val_b, Struct): - return False - else: - try: - tree.assert_same_structure(val_a, val_b, check_types=True) - except (ValueError, TypeError): - return False - return True - - -def map_structure(fn: Callable[..., object], *structures: Struct) -> object: - """Applies `fn` to each entry in `structure` and returns a new structure. - - This is a special implementation of `tf.nest.map_structure` - that works for `Struct`. - - Args: - fn: a callable that accepts as many arguments as there are structures. - *structures: a scalar, tuple, or list of constructed scalars and/or - tuples/lists, or scalars. Note: numpy arrays are considered scalars. - - Returns: - A new structure with the same arity as `structure` and same type as - `structure[0]`, whose values correspond to `fn(x[0], x[1], ...)` where - `x[i]` is a value in the corresponding location in `structure[i]`. - - Raises: - TypeError: if `fn` is not a callable, or *structure is not all `Struct` or - all `tf.Tensor` typed values. - ValueError: if `*structure` is empty. - """ - if not structures: - raise ValueError('Must provide at least one structure') - - if not all(isinstance(s, Struct) for s in structures): - return tree.map_structure(fn, *structures) - - for i, other in enumerate(structures[1:]): - if not is_same_structure(structures[0], other): - raise TypeError( - 'Structure at position {} is not the same structure'.format(i) - ) - - flat_structure = [flatten(s) for s in structures] - entries = zip(*flat_structure) - s = [fn(*x) for x in entries] - - return pack_sequence_as(structures[0], s) - - -def from_container(value: object, recursive=False) -> Struct: - """Creates an instance of `Struct` from a Python container. - - By default, this conversion is only performed at the top level for Python - dictionaries, `collections.OrderedDict`s, `namedtuple`s, `list`s, - `tuple`s, and `attr.s` classes. Elements of these structures are not - recursively converted. - - Args: - value: _The Python container to convert. - recursive: Whether to convert elements recursively (`False` by default). - - Returns: - The corresponding instance of `Struct`. - - Raises: - TypeError: If the `value` is not of one of the supported container types. - """ - - def _convert( - value: object, recursive: bool, must_be_container: bool = False - ) -> Struct: - """The actual conversion function. - - Args: - value: Same as in `from_container`. - recursive: Same as in `from_container`. - must_be_container: When set to `True`, causes an exception to be raised if - `value` is not a container. - - Returns: - The result of conversion. - - Raises: - TypeError: If `value` is not a container and `must_be_container` has been - set to `True`. - """ - # TODO: b/224484886 - Downcasting to all handled types. - value = typing.cast( - Union[ - Struct, - py_typecheck.SupportsNamedTuple, - Mapping[str, object], - dict[str, object], - tuple[object, ...], - list[object], - ], - value, - ) - if isinstance(value, Struct): - if recursive: - return Struct((k, _convert(v, True)) for k, v in iter_elements(value)) - else: - return value - elif attrs.has(type(value)): - return _convert( - attrs.asdict(value, recurse=False), recursive, must_be_container - ) - elif isinstance(value, py_typecheck.SupportsNamedTuple): - return _convert(value._asdict(), recursive, must_be_container) - elif isinstance(value, Mapping): - items = value.items() - if recursive: - return Struct((k, _convert(v, True)) for k, v in items) - else: - return Struct(items) - elif isinstance(value, (tuple, list)): - if recursive: - return Struct((None, _convert(v, True)) for v in value) - else: - return Struct((None, v) for v in value) - elif must_be_container: - raise TypeError( - 'Unable to convert a Python object of type {} into ' - 'an `Struct`. Object: {}'.format( - py_typecheck.type_string(type(value)), value - ) - ) - else: - return value - - return _convert(value, recursive, must_be_container=True) - - -def _to_container_recursive( - value: Struct[_T], - container_fn: Callable[[list[tuple[Optional[str], _T]]], _U], -) -> _U: - """Recursively converts the `Struct` `value` to a new container type. - - This function is always recursive, since the non-recursive version would be - just `container_fn(value)`. - - Note: _This function will only recurse through `Struct`s, so if called - on the input `Struct([('a', 1), ('b', {'c': Struct(...)})])` - the inner `Struct` will not be converted, because we do not recurse - through Python `dict`s. - - Args: - value: An `Struct`, possibly nested. - container_fn: A function that takes a `list` of `(name, value)` tuples ( the - elements of an `Struct`), and returns a new container holding the same - values. - - Returns: - A nested container of the type returned by `container_fn`. - """ - - def recurse(v): - if isinstance(v, Struct): - return _to_container_recursive(v, container_fn) - else: - return v - - return container_fn([(k, recurse(v)) for k, v in iter_elements(value)]) - - -def has_field(structure: Struct, field: str) -> bool: - """Returns `True` if the `structure` has the `field`. - - Args: - structure: An instance of `Struct`. - field: A string, the field to test for. - """ - return field in structure._name_array # pylint: disable=protected-access - - -def name_to_index_map(structure: Struct) -> dict[str, int]: - """Returns map from names in `structure` to their indices. - - Args: - structure: An instance of `Struct`. - - Returns: - Mapping from names in `structure` to their indices. - """ - return structure._name_to_index # pylint: disable=protected-access - - -def update_struct(structure, **kwargs): - """Constructs a new `structure` with new values for fields in `kwargs`. - - This is a helper method for working structured objects in a functional manner. - This method will create a new structure where the fields named by keys in - `kwargs` replaced with the associated values. - - NOTE: _This method only works on the first level of `structure`, and does not - recurse in the case of nested structures. A field that is itself a structure - can be replaced with another structure. - - Args: - structure: _The structure with named fields to update. - **kwargs: _The list of key-value pairs of fields to update in `structure`. - - Returns: - A new instance of the same type of `structure`, with the fields named - in the keys of `**kwargs` replaced with the associated values. - - Raises: - KeyError: If kwargs contains a field that is not in structure. - TypeError: If structure is not a structure with named fields. - """ - if not isinstance( - structure, (Struct, Mapping, py_typecheck.SupportsNamedTuple) - ) and not attrs.has(type(structure)): - raise TypeError( - '`structure` must be a structure with named fields (e.g. ' - 'dict, attrs class, collections.namedtuple, ' - 'tff.structure.Struct), but found {}'.format(type(structure)) - ) - if isinstance(structure, Struct): - elements = [ - (k, v) if k not in kwargs else (k, kwargs.pop(k)) - for k, v in iter_elements(structure) - ] - if kwargs: - raise KeyError(f'`structure` does not contain fields named {kwargs}') - return Struct(elements) - elif isinstance(structure, py_typecheck.SupportsNamedTuple): - # In Python 3.8 and later `_asdict` no longer return OrdereDict, rather a - # regular `dict`, so we wrap here to get consistent types across Python - # version.s - dictionary = structure._asdict() - elif attrs.has(type(structure)): - dictionary = attrs.asdict(structure, recurse=False) - else: - for key in kwargs: - if key not in structure: - raise KeyError( - 'structure does not contain a field named "{!s}"'.format(key) - ) - # Create a copy to prevent mutation of the original `structure` - dictionary = type(structure)(**structure) - dictionary.update(kwargs) - if isinstance(structure, Mapping): - return dictionary - return type(structure)(**dictionary) +from federated_language.common_libs import structure + +Struct = structure.Struct +name_list = structure.name_list +name_list_with_nones = structure.name_list_with_nones +to_elements = structure.to_elements +iter_elements = structure.iter_elements +to_odict = structure.to_odict +to_odict_or_tuple = structure.to_odict_or_tuple +flatten = structure.flatten +pack_sequence_as = structure.pack_sequence_as +is_same_structure = structure.is_same_structure +map_structure = structure.map_structure +from_container = structure.from_container +has_field = structure.has_field +name_to_index_map = structure.name_to_index_map +update_struct = structure.update_struct diff --git a/tensorflow_federated/python/common_libs/structure_test.py b/tensorflow_federated/python/common_libs/structure_test.py deleted file mode 100644 index 992e5e5186..0000000000 --- a/tensorflow_federated/python/common_libs/structure_test.py +++ /dev/null @@ -1,636 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import operator - -from absl.testing import absltest -from absl.testing import parameterized -import attrs - -from tensorflow_federated.python.common_libs import structure - - -class StructTest(parameterized.TestCase): - - def test_new_named(self): - x = structure.Struct.named(a=1, b=4) - self.assertSequenceEqual(structure.to_elements(x), [('a', 1), ('b', 4)]) - - def test_new_unnamed(self): - x = structure.Struct.unnamed(1, 4) - self.assertSequenceEqual(structure.to_elements(x), [(None, 1), (None, 4)]) - - def test_construction_from_list(self): - v = [('a', 1), ('b', 2), (None, 3)] - x = structure.Struct(v) - self.assertSequenceEqual(structure.to_elements(x), v) - - def test_construction_from_tuple(self): - v = (('a', 1), ('b', 2), (None, 3)) - x = structure.Struct(v) - self.assertSequenceEqual(structure.to_elements(x), v) - - def test_construction_from_ordereddict(self): - v = collections.OrderedDict(a=1, b=2, c=3) - x = structure.Struct(v.items()) - self.assertSequenceEqual(structure.to_elements(x), list(v.items())) - - def test_construction_from_generator_expression(self): - x = structure.Struct((name, i) for i, name in enumerate(('a', 'b', None))) - self.assertSequenceEqual( - structure.to_elements(x), [('a', 0), ('b', 1), (None, 2)] - ) - - def test_construction_from_iter_elements(self): - x = structure.Struct((('a', 1), ('b', 2), (None, 3))) - self.assertSequenceEqual(structure.Struct(structure.iter_elements(x)), x) - - def test_empty(self): - v = [] - x = structure.Struct(v) - self.assertEmpty(x) - self.assertRaises(IndexError, lambda _: x[0], None) - self.assertEqual(list(iter(x)), []) - self.assertEqual(dir(x), []) - self.assertRaises(AttributeError, lambda _: x.foo, None) - self.assertEqual(x, structure.Struct([])) - self.assertNotEqual(x, structure.Struct([('foo', 10)])) - self.assertEqual(structure.to_elements(x), v) - self.assertEqual(structure.to_odict(x), collections.OrderedDict()) - self.assertEqual(structure.to_odict_or_tuple(x), ()) - self.assertEqual(repr(x), 'Struct([])') - self.assertEqual(str(x), '<>') - - def test_single_unnamed(self): - v = [(None, 10)] - x = structure.Struct(v) - self.assertLen(x, 1) - self.assertRaises(IndexError, lambda _: x[1], None) - self.assertEqual(x[0], 10) - self.assertEqual(list(iter(x)), [10]) - self.assertEqual(dir(x), []) - self.assertRaises(AttributeError, lambda _: x.foo, None) - self.assertNotEqual(x, structure.Struct([])) - self.assertNotEqual(x, structure.Struct([('foo', 10)])) - self.assertEqual(x, structure.Struct([(None, 10)])) - self.assertNotEqual(x, structure.Struct([(None, 10), ('foo', 20)])) - self.assertEqual(structure.to_elements(x), v) - self.assertEqual(repr(x), 'Struct([(None, 10)])') - self.assertEqual(str(x), '<10>') - self.assertEqual(structure.to_odict_or_tuple(x), tuple([10])) - with self.assertRaisesRegex(ValueError, 'unnamed'): - structure.to_odict(x) - - def test_single_named(self): - v = [('foo', 20)] - x = structure.Struct(v) - self.assertLen(x, 1) - self.assertEqual(x[0], 20) - self.assertRaises(IndexError, lambda _: x[1], None) - self.assertEqual(list(iter(x)), [20]) - self.assertEqual(dir(x), ['foo']) - self.assertEqual(x.foo, 20) - self.assertRaises(AttributeError, lambda _: x.bar, None) - self.assertNotEqual(x, structure.Struct([])) - self.assertNotEqual(x, structure.Struct([('foo', 10)])) - self.assertNotEqual(x, structure.Struct([(None, 20)])) - self.assertEqual(x, structure.Struct([('foo', 20)])) - self.assertNotEqual(x, structure.Struct([('foo', 20), ('bar', 30)])) - self.assertEqual(structure.to_elements(x), v) - self.assertEqual(repr(x), "Struct([('foo', 20)])") - self.assertEqual(str(x), '') - self.assertEqual(structure.to_odict(x), collections.OrderedDict(v)) - self.assertEqual(structure.to_odict_or_tuple(x), collections.OrderedDict(v)) - - def test_multiple_named_and_unnamed(self): - v = [(None, 10), ('foo', 20), ('bar', 30)] - x = structure.Struct(v) - self.assertLen(x, 3) - self.assertEqual(x[0], 10) - self.assertEqual(x[1], 20) - self.assertEqual(x[2], 30) - self.assertRaises(IndexError, lambda _: x[3], None) - self.assertEqual(list(iter(x)), [10, 20, 30]) - self.assertEqual(dir(x), ['bar', 'foo']) - self.assertEqual(structure.name_list(x), ['foo', 'bar']) - self.assertEqual(x.foo, 20) - self.assertEqual(x.bar, 30) - self.assertRaises(AttributeError, lambda _: x.baz, None) - self.assertEqual( - x, structure.Struct([(None, 10), ('foo', 20), ('bar', 30)]) - ) - self.assertNotEqual( - x, structure.Struct([('foo', 10), ('bar', 20), (None, 30)]) - ) - self.assertEqual(structure.to_elements(x), v) - self.assertEqual(repr(x), "Struct([(None, 10), ('foo', 20), ('bar', 30)])") - self.assertEqual(str(x), '<10,foo=20,bar=30>') - with self.assertRaisesRegex(ValueError, 'unnamed'): - structure.to_odict(x) - with self.assertRaisesRegex(ValueError, 'named and unnamed'): - structure.to_odict_or_tuple(x) - - def test_bad_names(self): - with self.assertRaisesRegex(ValueError, 'duplicated.*foo'): - structure.Struct([('foo', 20), ('foo', 30)]) - - with self.assertRaisesRegex(ValueError, '_asdict.*reserved'): - structure.Struct.named(_asdict=40) - - with self.assertRaisesRegex(ValueError, '_element_array.*reserved'): - structure.Struct.named(_element_array=40) - - with self.assertRaisesRegex(ValueError, '_name_to_index.*reserved'): - structure.Struct.named(_name_to_index=40) - - with self.assertRaisesRegex(ValueError, '_name_array.*reserved'): - structure.Struct.named(_name_array=40) - - with self.assertRaisesRegex(ValueError, '_hash.*reserved'): - structure.Struct.named(_hash=40) - - def test_immutable(self): - t = structure.Struct.named(foo='a string', bar=1, baz=[1.0, 2.0, 3.0]) - - # Expect that we can read by name the values. - self.assertEqual(t.foo, 'a string') - self.assertEqual(t[0], 'a string') - self.assertEqual(t.bar, 1) - self.assertEqual(t[1], 1) - self.assertEqual(t.baz, [1.0, 2.0, 3.0]) - self.assertEqual(t[2], [1.0, 2.0, 3.0]) - - # But trying to set an attribute fails. - - # These raise "AttributeError" saying that the particular attribute is - # unknown. This can look strange because the attribute was "known" above. - with self.assertRaises(AttributeError): - t.foo = 'a different string' - with self.assertRaises(AttributeError): - t.bar = 5 - with self.assertRaises(AttributeError): - t.baz = [1, 2, 3] - - # These raise "TypeError" saying that tuples are immutable. - with self.assertRaises(TypeError): - t[0] = 'a different string' - with self.assertRaises(TypeError): - t[1] = 5 - with self.assertRaises(TypeError): - t[2] = [1, 2, 3] - - def test_equality_unnamed(self): - # identity - t1 = structure.Struct([(None, 1), (None, 2)]) - self.assertTrue(t1.__eq__(t1)) - self.assertFalse(t1.__ne__(t1)) - # different type - self.assertIs(t1.__eq__(None), NotImplemented) - self.assertTrue(t1.__ne__(None)) - # copy - t2 = structure.Struct([(None, 1), (None, 2)]) - self.assertTrue(t1.__eq__(t2)) - self.assertTrue(t2.__eq__(t1)) - self.assertFalse(t1.__ne__(t2)) - self.assertFalse(t2.__ne__(t1)) - # different ordering - t3 = structure.Struct([(None, 2), (None, 1)]) - self.assertFalse(t1.__eq__(t3)) - self.assertFalse(t3.__eq__(t1)) - self.assertTrue(t1.__ne__(t3)) - self.assertTrue(t3.__ne__(t1)) - # different names - t4 = structure.Struct([('a', 1), ('b', 2)]) - self.assertFalse(t1.__eq__(t4)) - self.assertFalse(t4.__eq__(t1)) - self.assertTrue(t1.__ne__(t4)) - self.assertTrue(t4.__ne__(t1)) - # different values - t5 = structure.Struct([(None, 10), (None, 10)]) - self.assertFalse(t1.__eq__(t5)) - self.assertFalse(t5.__eq__(t1)) - self.assertTrue(t1.__ne__(t5)) - self.assertTrue(t5.__ne__(t1)) - - def test_equality_named(self): - # identity - t1 = structure.Struct.named(a=1, b=2) - self.assertTrue(t1.__eq__(t1)) - self.assertFalse(t1.__ne__(t1)) - # different type - self.assertIs(t1.__eq__(None), NotImplemented) - self.assertTrue(t1.__ne__(None)) - # copy - t2 = structure.Struct.named(a=1, b=2) - self.assertTrue(t1.__eq__(t2)) - self.assertTrue(t2.__eq__(t1)) - self.assertFalse(t1.__ne__(t2)) - self.assertFalse(t2.__ne__(t1)) - # different ordering - t3 = structure.Struct.named(b=2, a=1) - self.assertFalse(t1.__eq__(t3)) - self.assertFalse(t3.__eq__(t1)) - self.assertTrue(t1.__ne__(t3)) - self.assertTrue(t3.__ne__(t1)) - # different names - t4 = structure.Struct.named(c=1, d=2) - self.assertFalse(t1.__eq__(t4)) - self.assertFalse(t4.__eq__(t1)) - self.assertTrue(t1.__ne__(t4)) - self.assertTrue(t4.__ne__(t1)) - # different values - t5 = structure.Struct.named(a=10, b=10) - self.assertFalse(t1.__eq__(t5)) - self.assertFalse(t5.__eq__(t1)) - self.assertTrue(t1.__ne__(t5)) - self.assertTrue(t5.__ne__(t1)) - - def test_hash(self): - v1 = [(str(i) if i > 30 else None, i) for i in range(0, 50, 10)] - x1 = structure.Struct(v1) - self.assertNotEqual(x1, v1) - self.assertNotEqual(hash(x1), hash(iter(v1))) - v2 = [(None, i) for i in range(0, 50, 10)] - x2 = structure.Struct(v2) - self.assertNotEqual(hash(x2), hash(iter(v2))) - self.assertNotEqual(x1, x2) - self.assertNotEqual(hash(x1), hash(x2)) - v3 = [(None, 0), (None, 10), (None, 20), (None, 30), (None, 40)] - x3 = structure.Struct(v3) - self.assertEqual(v2, v3) - self.assertEqual(x2, x3) - self.assertEqual(hash(x2), hash(x3)) - - def test_slicing_behavior(self): - x = structure.Struct.unnamed(*tuple(range(0, 50, 10))) - self.assertEqual(x[:], tuple(range(0, 50, 10))) - self.assertEqual(x[::-1], tuple(reversed(range(0, 50, 10)))) - self.assertEqual(x[:-1], tuple(range(0, 40, 10))) - self.assertEqual(x[1:], tuple(range(10, 50, 10))) - self.assertEqual(x[-1:], (40,)) - - def test_getitem_key(self): - x = structure.Struct.named(foo=10, bar=20) - self.assertEqual(x['foo'], 10) - self.assertEqual(x['bar'], 20) - with self.assertRaises(AttributeError): - _ = x['badkey'] - - def test_getitem_key_builtin_attribute_raises(self): - x = structure.Struct.named(foo=10, bar=20) - with self.assertRaises(AttributeError): - _ = x['__getattr__'] - - def test_getitem_bad_bounds(self): - x = structure.Struct.unnamed(*tuple(range(0, 50, 10))) - with self.assertRaises(IndexError): - _ = x[10] - - def test_pack_sequence_as_fails_non_struct(self): - x = structure.Struct.named(a=10, b=dict(d=20), c=30) - y = [10, 20, 30] - with self.assertRaisesRegex(TypeError, 'Cannot pack sequence'): - _ = structure.pack_sequence_as(x, y) - - def test_flatten_and_pack_sequence_as(self): - x = structure.Struct.named( - a=10, - b=structure.Struct.named( - x=structure.Struct.named(p=40), - y=30, - z=structure.Struct.named(q=50, r=60), - ), - c=20, - ) - y = structure.flatten(x) - self.assertEqual(y, [10, 40, 30, 50, 60, 20]) - z = structure.pack_sequence_as(x, y) - self.assertEqual(str(z), ',y=30,z=>,c=20>') - - def test_is_same_structure_check_types(self): - self.assertTrue( - structure.is_same_structure( - structure.Struct.named(a=10), structure.Struct.named(a=20) - ) - ) - self.assertTrue( - structure.is_same_structure( - structure.Struct.named( - a=10, - b=structure.Struct.named(z=5), - ), - structure.Struct.named(a=20, b=structure.Struct.named(z=50)), - ) - ) - self.assertFalse( - structure.is_same_structure( - structure.Struct.named(x=dict(y=4)), - structure.Struct.named(x=dict(y=5, z=6)), - ) - ) - self.assertTrue( - structure.is_same_structure( - structure.Struct.named(x=dict(y=5)), - structure.Struct.named(x=dict(y=6)), - ) - ) - - def test_map_structure(self): - x = structure.Struct.named( - a=10, - b=structure.Struct.named( - x=structure.Struct.named(p=40), - y=30, - z=structure.Struct.named(q=50, r=60), - ), - c=20, - ) - y = structure.Struct.named( - a=1, - b=structure.Struct.named( - x=structure.Struct.named(p=4), - y=3, - z=structure.Struct.named(q=5, r=6), - ), - c=2, - ) - - self.assertEqual( - structure.map_structure(operator.add, x, y), - structure.Struct.named( - a=11, - b=structure.Struct.named( - x=structure.Struct.named(p=44), - y=33, - z=structure.Struct.named(q=55, r=66), - ), - c=22, - ), - ) - - def test_map_structure_fails_different_structures(self): - x = structure.Struct.named(a=10, c=20) - y = structure.Struct.named(a=30) - with self.assertRaises(TypeError): - structure.map_structure(operator.add, x, y) - x = structure.Struct.named(a=10) - y = structure.Struct.named(a=30, c=['a', 'b', 'c']) - with self.assertRaises(TypeError): - structure.map_structure(operator.add, x, y) - - def test_map_structure_non_structs(self): - x = 1 - y = 2 - result = structure.map_structure(operator.add, x, y) - self.assertEqual(result, 3) - - x = ['a', 'b', 'c'] - y = ['x', 'y', 'z'] - result = structure.map_structure(operator.add, x, y) - self.assertEqual(result, ['ax', 'by', 'cz']) - - def test_from_container_with_none(self): - with self.assertRaises(TypeError): - structure.from_container(None) - - def test_from_container_with_int(self): - with self.assertRaises(TypeError): - structure.from_container(10) - - def test_from_container_with_list(self): - x = structure.from_container([10, 20]) - self.assertIsInstance(x, structure.Struct) - self.assertEqual(str(x), '<10,20>') - - def test_from_container_with_tuple(self): - x = structure.from_container(tuple([10, 20])) - self.assertIsInstance(x, structure.Struct) - self.assertEqual(str(x), '<10,20>') - - def test_from_container_with_dict(self): - x = structure.from_container({'z': 10, 'y': 20, 'a': 30}) - self.assertIsInstance(x, structure.Struct) - self.assertEqual(str(x), '') - - def test_from_container_with_ordered_dict(self): - x = structure.from_container( - collections.OrderedDict([('z', 10), ('y', 20), ('a', 30)]) - ) - self.assertIsInstance(x, structure.Struct) - self.assertEqual(str(x), '') - - def test_from_container_with_namedtuple(self): - x = structure.from_container(collections.namedtuple('_', 'x y')(1, 2)) - self.assertIsInstance(x, structure.Struct) - self.assertEqual(str(x), '') - - def test_from_container_with_attrs_class(self): - - @attrs.define - class TestFoo: - x: int - y: int - - x = structure.from_container(TestFoo(1, 2)) - self.assertIsInstance(x, structure.Struct) - self.assertEqual(str(x), '') - - def test_from_container_with_struct(self): - x = structure.from_container(structure.Struct([('a', 10), ('b', 20)])) - self.assertIs(x, x) - - def test_from_container_with_namedtuple_of_odict_recursive(self): - x = structure.from_container( - collections.namedtuple('_', 'x y')( - collections.OrderedDict([('a', 10), ('b', 20)]), - collections.OrderedDict([('c', 30), ('d', 40)]), - ), - recursive=True, - ) - self.assertEqual(str(x), ',y=>') - - @parameterized.named_parameters( - ('empty', collections.OrderedDict()), - ('flat', collections.OrderedDict(a=1, b=2)), - ( - 'nested', - collections.OrderedDict( - a=1, - b=2, - c=collections.OrderedDict( - d=3, e=collections.OrderedDict(f=4, g=5) - ), - ), - ), - ) - def test_from_container_asdict_roundtrip(self, dict_in): - structure_repr = structure.from_container(dict_in, recursive=True) - dict_out = structure_repr._asdict(recursive=True) - self.assertEqual(dict_in, dict_out) - - def test_from_container_raises_on_non_container_argument(self): - with self.assertRaises(TypeError): - structure.from_container(3) - - def test_name_to_index_map_empty_unnamed_struct(self): - unnamed_struct = structure.Struct.unnamed(10, 20) - self.assertEmpty(structure.name_to_index_map(unnamed_struct)) - - def test_name_to_index_map_partially_named_struct(self): - partially_named_struct = structure.Struct([(None, 10), ('a', 20)]) - - name_to_index_dict = structure.name_to_index_map(partially_named_struct) - expected_name_to_index_map = {'a': 1} - self.assertEqual(name_to_index_dict, expected_name_to_index_map) - - def test_name_to_index_map_fully_named_struct(self): - partially_named_struct = structure.Struct.named(b=10, a=20) - - name_to_index_dict = structure.name_to_index_map(partially_named_struct) - expected_name_to_index_map = {'b': 0, 'a': 1} - self.assertEqual(name_to_index_dict, expected_name_to_index_map) - - def test_update_struct(self): - with self.subTest('fully_named'): - state = structure.Struct.named(a=1, b=2, c=3) - state = structure.update_struct(state, c=7) - self.assertEqual(state, structure.Struct.named(a=1, b=2, c=7)) - state = structure.update_struct(state, a=8) - self.assertEqual(state, structure.Struct.named(a=8, b=2, c=7)) - with self.subTest('partially_named'): - state = structure.Struct([(None, 1), ('b', 2), (None, 3)]) - state = structure.update_struct(state, b=7) - self.assertEqual( - state, structure.Struct([(None, 1), ('b', 7), (None, 3)]) - ) - with self.assertRaises(KeyError): - structure.update_struct(state, a=8) - with self.subTest('nested'): - state = structure.Struct.named(a=dict(a1=1, a2=2), b=2, c=3) - state = structure.update_struct(state, a=7) - self.assertEqual(state, structure.Struct.named(a=7, b=2, c=3)) - state = structure.update_struct(state, a=dict(foo=1, bar=2)) - self.assertEqual( - state, structure.Struct.named(a=dict(foo=1, bar=2), b=2, c=3) - ) - with self.subTest('unnamed'): - state = structure.Struct.unnamed(*tuple(range(3))) - with self.assertRaises(KeyError): - structure.update_struct(state, a=1) - with self.assertRaises(KeyError): - structure.update_struct(state, b=1) - - def test_update_struct_namedtuple(self): - my_tuple_type = collections.namedtuple('my_tuple_type', 'a b c') - state = my_tuple_type(1, 2, 3) - state2 = structure.update_struct(state, c=7) - self.assertEqual(state2, my_tuple_type(1, 2, 7)) - state3 = structure.update_struct(state2, a=8) - self.assertEqual(state3, my_tuple_type(8, 2, 7)) - - def test_update_struct_dict(self): - state = collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)]) - state2 = structure.update_struct(state, c=7) - self.assertEqual(state2, {'a': 1, 'b': 2, 'c': 7}) - state3 = structure.update_struct(state2, a=8) - self.assertEqual(state3, {'a': 8, 'b': 2, 'c': 7}) - - def test_update_struct_on_dict_does_not_mutate_original(self): - state = collections.OrderedDict(a=1, b=2, c=3) - state2 = structure.update_struct(state, c=7) - del state2 - self.assertEqual(state, collections.OrderedDict(a=1, b=2, c=3)) - - def test_update_struct_ordereddict(self): - state = collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)]) - state2 = structure.update_struct(state, c=7) - self.assertEqual( - state2, collections.OrderedDict([('a', 1), ('b', 2), ('c', 7)]) - ) - state3 = structure.update_struct(state2, a=8) - self.assertEqual( - state3, collections.OrderedDict([('a', 8), ('b', 2), ('c', 7)]) - ) - - def test_update_struct_attrs(self): - - @attrs.define - class TestAttrsClass: - a: int - b: int - c: int - - state = TestAttrsClass(1, 2, 3) - state2 = structure.update_struct(state, c=7) - self.assertEqual(state2, TestAttrsClass(1, 2, 7)) - state3 = structure.update_struct(state2, a=8) - self.assertEqual(state3, TestAttrsClass(8, 2, 7)) - - def test_update_struct_fails(self): - with self.assertRaisesRegex(TypeError, '`structure` must be a structure'): - structure.update_struct((1, 2, 3), a=8) - with self.assertRaisesRegex(TypeError, '`structure` must be a structure'): - structure.update_struct([1, 2, 3], a=8) - with self.assertRaisesRegex(KeyError, 'does not contain a field'): - structure.update_struct({'z': 1}, a=8) - - @parameterized.named_parameters( - ('empty_tuple', ()), - ('flat_tuple', (1, 2)), - ('nested_tuple', (1, 2, (3, (4, 5)))), - ('flat_dict', collections.OrderedDict(a=1, b=2)), - ( - 'nested_dict', - collections.OrderedDict( - a=1, - b=2, - c=collections.OrderedDict( - d=3, e=collections.OrderedDict(f=4, g=5) - ), - ), - ), - ( - 'mixed', - collections.OrderedDict( - a=1, b=2, c=(3, collections.OrderedDict(d=4, e=5)) - ), - ), - ) - def test_to_odict_or_tuple_from_container_roundtrip(self, original): - structure_repr = structure.from_container(original, recursive=True) - out = structure.to_odict_or_tuple(structure_repr) - self.assertEqual(original, out) - - def test_to_odict_or_tuple_empty_dict_becomes_empty_tuple(self): - s = collections.OrderedDict() - x = structure.from_container(s) - self.assertEqual(structure.to_odict_or_tuple(x), ()) - - def test_to_odict_or_tuple_mixed_nonrecursive(self): - s = collections.OrderedDict( - a=1, b=2, c=(3, collections.OrderedDict(d=4, e=5)) - ) - x = structure.from_container(s, recursive=False) - self.assertEqual(s, structure.to_odict_or_tuple(x, recursive=False)) - - def test_to_odict_or_tuple_raises_on_mixed_named_and_unnamed(self): - s = [(None, 10), ('foo', 20), ('bar', 30)] - x = structure.Struct(s) - with self.assertRaisesRegex(ValueError, 'named and unnamed'): - structure.to_odict_or_tuple(x) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/common_libs/tracing.py b/tensorflow_federated/python/common_libs/tracing.py deleted file mode 100644 index a2c8f09701..0000000000 --- a/tensorflow_federated/python/common_libs/tracing.py +++ /dev/null @@ -1,455 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utility functions for instrumenting code with timing and tracing data. - -This module provides several functions for preserving trace context across -various boundaries, namely between asyncio and regular python code: - - * wrap_coroutine_in_trace_context wraps a coroutine such that it - inherits the ambient trace context. It should be used when executing a - coroutine that should inherit trace context from the current thread or - task. - * EventLoops should use the Task factory provided by - propagate_trace_context_task_factory by calling - `set_task_factory(propagate_trace_context_task_factory)`. -""" - -import abc -import asyncio -from collections.abc import Generator -import contextlib -import functools -import inspect -import random -import sys -import threading -import time -from typing import Generic, Optional, TypeVar, Union - -from absl import logging - -from tensorflow_federated.python.common_libs import py_typecheck - - -class TracedSpan: - """The trace was wrapping a non-function span. - - This value will be given back from `TracingProvider::span`'s first `yield` - if the trace was being used to wrap a `span` rather than a whole function. - """ - - pass - - -class TracedFunctionReturned: - """The traced function returned successfully. - - This value will be given back from `TracingProvider::span`'s first `yield` - if the function being traced returned normally. The return value will be kept - in the `value` field. - """ - - def __init__(self, value): - self.value = value - - -class TracedFunctionThrew: - """The traced function threw an exception. - - This value will be given back from `TracingProvider::span`'s first `yield` - if the function being traced threw an exception. - """ - - def __init__(self, error_type, error_value, traceback): - self.error_type = error_type - self.error_value = error_value - self.traceback = traceback - - -TraceResult = Union[TracedSpan, TracedFunctionReturned, TracedFunctionThrew] - -T = TypeVar('T') - - -class TracingProvider(Generic[T], metaclass=abc.ABCMeta): - """Abstract base class for tracers.""" - - @abc.abstractmethod - def span( - self, - scope: str, - sub_scope: str, - nonce: int, - parent_span_yield: Optional[T], - fn_args: Optional[tuple[object, ...]], - fn_kwargs: Optional[dict[str, object]], - trace_opts: dict[str, object], - ) -> Generator[T, TraceResult, None]: - """Create a new tracing span. - - Args: - scope: String name of the scope, often the class name. - sub_scope: String name of the sub-scope, often the function name. - nonce: Number used to correlate tracing messages relating to the same - function invocation. - parent_span_yield: The value yielded by the most recently started (and not - exited) call to `span` on this `TracingProvider` on the current - `asyncio.Task` or thread (when running outside of an async context). - fn_args: When this tracing provider wraps a function, this will be a tuple - containing all of the non-keyword function arguments. - fn_kwargs: When this tracing provider wraps a function, this will be a - dict containing all of the keyword function arguments. - trace_opts: User-provided options to the span constructor. - `TracingProvider`s should ignore unknown options. - - Returns: - A `Generator` which will be immediately started and run up until it - yields for the first time. The value yielded by this `Generator` - will be passed on to nested calls to `span`. When the spanned code ends, - a `TraceResult` will be passed back through the `yield`. - """ - raise NotImplementedError - - def wrap_rpc( - self, parent_span_yield: Optional[T] - ) -> contextlib.AbstractContextManager[None]: - """Wrap an RPC call so that it can carry over the `parent_span_yield`.""" - del parent_span_yield - return contextlib.nullcontext() - - def receive_rpc(self) -> Optional[T]: - """Unpack `parent_span_yield` from the receiving end of an RPC.""" - return None - - -class LoggingTracingProvider(TracingProvider[None]): - """Implements TracingProvider and outputs the results via logging. - - This implementation does not require storing additional trace context state, - so most methods are no-ops. - """ - - def span( - self, - scope: str, - sub_scope: str, - nonce: int, - parent_span_yield: Optional[None], - fn_args: Optional[tuple[object, ...]], - fn_kwargs: Optional[dict[str, object]], - trace_opts: dict[str, object], - ) -> Generator[None, TraceResult, None]: - assert parent_span_yield is None - del parent_span_yield, fn_args, fn_kwargs, trace_opts - start_time = time.time() - logging.debug('(%s) Entering %s.%s', nonce, scope, sub_scope) - yield None - logging.debug( - '(%s) Exiting %s.%s. Elapsed time %f', - nonce, - scope, - sub_scope, - time.time() - start_time, - ) - - -_global_tracing_providers = [LoggingTracingProvider()] - - -def trace(fn=None, **trace_kwargs): - """Delegates to the current global `TracingProvider`. - - Note that this function adds a layer of indirection so that the decoration - happens when the method is executed. This is necessary so that the current - TracingProvider is used. - - Args: - fn: Function to decorate. - **trace_kwargs: Tracing options. Supported options differ by tracing - provider. - - Returns: - Decorated instance of fn. - """ - if fn is None: - return functools.partial(trace, **trace_kwargs) - - scope, sub_scope = _func_to_class_and_method(fn) - - # Note: in a classic "what color is your function" situation, - # we unfortunately have to duplicate the wrapping of the - # underlying function in order to cover both the sync and async cases. - if inspect.iscoroutinefunction(fn): - - @functools.wraps(fn) - async def async_trace(*fn_args, **fn_kwargs): - # Produce the span generator - span_gen = _span_generator( - scope, sub_scope, trace_kwargs, fn_args=fn_args, fn_kwargs=fn_kwargs - ) - # Run up until the first yield - next(span_gen) - completed = False - # Run the underlying function, recording the resulting value or exception - # and passing it back to the span generator - try: - result = await fn(*fn_args, **fn_kwargs) - completed = True - try: - span_gen.send(TracedFunctionReturned(result)) - except StopIteration: - pass - return result - except: - if not completed: - error_type, error_value, traceback = sys.exc_info() - try: - span_gen.send( - TracedFunctionThrew(error_type, error_value, traceback) - ) - except StopIteration: - pass - raise - - return async_trace - else: - - @functools.wraps(fn) - def sync_trace(*fn_args, **fn_kwargs): - span_gen = _span_generator( - scope, sub_scope, trace_kwargs, fn_args=fn_args, fn_kwargs=fn_kwargs - ) - next(span_gen) - completed = False - try: - result = fn(*fn_args, **fn_kwargs) - completed = True - try: - span_gen.send(TracedFunctionReturned(result)) - except StopIteration: - pass - return result - except: - if not completed: - error_type, error_value, traceback = sys.exc_info() - try: - span_gen.send( - TracedFunctionThrew(error_type, error_value, traceback) - ) - except StopIteration: - pass - raise - - return sync_trace - - -# The code below manages the active "span yields" for a task or thread. -# Here's a quick summary of how that works. -# -# A "span yield" is a value `yield`ed by the `TracingProvider.span` function. -# The span yields for the current encompassing span need to be tracked so that -# they can be passed to new calls to `span` as the `parent_span_yield`. -# -# Typically, these would be tracked with a thread-local. However, async tasks -# can interleave on a single thread, so it makes more sense for them to track -# "task locals". -# -# `_current_span_yields` and `_set_span_yields` below handle the logic of -# tracking these spans. If we're in an async context, they'll read and write -# to the current async tasks, but fall back to using a thread local if we're -# in a synchronous context. - -# A single yielded value for each currently-active TracingProvider. -SpanYields = list[Optional[object]] - - -class ThreadLocalSpanYields(threading.local): - """The span set for the current thread. - - This is only used when outside of an async context. - """ - - def __init__(self): - super().__init__() - self._span_yields: Optional[SpanYields] = None - - def set(self, span_yields: Optional[SpanYields]): - self._span_yields = span_yields - - def get(self) -> Optional[SpanYields]: - return self._span_yields - - -_non_async_span_yields = ThreadLocalSpanYields() - - -def _current_task() -> Optional[asyncio.Task]: - """Get the current running task, or `None` if no task is running.""" - # Note: `current_task` returns `None` if there is no current task, but it - # throws if no currently running async loop. - try: - return asyncio.current_task() - except RuntimeError: - return None - - -def _current_span_yields() -> SpanYields: - """Returns the current parent span yield list.""" - task = _current_task() - if task is None: - # There is no current task, so we're not running in an async context. - # Grab the spans from the current thread. - spans = _non_async_span_yields.get() - else: - spans = getattr(task, 'trace_span_yields', None) - if spans is None: - spans = [None for _ in range(len(_global_tracing_providers))] - assert len(_global_tracing_providers) == len(spans) - return spans - - -def _set_span_yields(span_yields: Optional[SpanYields]): - """Sets the current parent span list.""" - task = _current_task() - if task is None: - # There is no current task, so we're not running in an async context. - # Set the spans for the current thread. - _non_async_span_yields.set(span_yields) - else: - setattr(task, 'trace_span_yields', span_yields) - - -@contextlib.contextmanager -def _with_span_yields(span_yields: Optional[SpanYields]): - """Context manager which sets and unsets the current parent span list.""" - old_span_yields = _current_span_yields() - _set_span_yields(span_yields) - yield None - _set_span_yields(old_span_yields) - - -@contextlib.contextmanager -def span(scope, sub_scope, **trace_opts): - """Creates a `ContextManager` that wraps the code in question with a span.""" - span_gen = _span_generator(scope, sub_scope, trace_opts) - next(span_gen) - yield - try: - span_gen.send(TracedSpan()) - except StopIteration: - pass - - -def _span_generator( - scope, sub_scope, trace_opts, fn_args=None, fn_kwargs=None -) -> Generator[None, TraceResult, None]: - """Wraps up all the `TracingProvider.span` generators into one.""" - # Create a nonce so that all of the traces from this span can be associated - # with one another. - nonce = random.randrange(1000000000) - # Call `span` on all the global `TraceProvider`s and run it up until `yield`. - span_generators = [] - new_span_yields: SpanYields = [] - for tp, parent_span_yield in zip( - _global_tracing_providers, _current_span_yields() - ): - new_span_gen = tp.span( - scope, - sub_scope, - nonce, - parent_span_yield, - fn_args, - fn_kwargs, - trace_opts, - ) - new_span_yield = next(new_span_gen) - span_generators.append(new_span_gen) - new_span_yields.append(new_span_yield) - # Set the values yielded by the `span` calls above to be the current span - # yields, and yield so that the function can be run to completion. - with _with_span_yields(new_span_yields): - result = yield None - # Send the result of the function to all of the generators so that they can - # complete. - for span_gen in reversed(span_generators): - try: - span_gen.send(result) - except StopIteration: - pass - - -def propagate_trace_context_task_factory(loop, coro): - """Creates a new task on `loop` to run `coro`, inheriting current spans.""" - child_task = asyncio.tasks.Task(coro, loop=loop) - trace_span_yields = _current_span_yields() - setattr(child_task, 'trace_span_yields', trace_span_yields) - return child_task - - -def wrap_coroutine_in_current_trace_context(coro): - """Wraps the coroutine in the currently active span.""" - trace_span_yields = _current_span_yields() - - async def _wrapped(): - with _with_span_yields(trace_span_yields): - return await coro - - return _wrapped() - - -@contextlib.contextmanager -def wrap_rpc_in_trace_context(): - """Attempts to record the trace context into the enclosed RPC call.""" - with contextlib.ExitStack() as stack: - for tp, parent_span_yield in zip( - _global_tracing_providers, _current_span_yields() - ): - stack.enter_context(tp.wrap_rpc(parent_span_yield)) - yield None - - -@contextlib.contextmanager -def with_trace_context_from_rpc(): - """Attempts to pick up the trace context from the receiving RPC call.""" - span_yields_from_rpc = [tp.receive_rpc() for tp in _global_tracing_providers] - with _with_span_yields(span_yields_from_rpc): - yield None - - -def add_tracing_provider(tracing_provider: TracingProvider): - """Add to the global list of tracing providers.""" - py_typecheck.check_type(tracing_provider, TracingProvider) - _global_tracing_providers.append(tracing_provider) - - -def set_tracing_providers(tracing_providers: list[TracingProvider]): - """Set the global list of tracing providers, replacing any existing.""" - py_typecheck.check_type(tracing_providers, list) - for tp in tracing_providers: - py_typecheck.check_type(tp, TracingProvider) - global _global_tracing_providers - _global_tracing_providers = tracing_providers - - -def _func_to_class_and_method(fn) -> tuple[str, str]: - """Returns the names of the function's class and method.""" - split = fn.__qualname__.split('.') - if len(split) >= 2: - class_name = split[-2] - method_name = split[-1] - else: - module_name = fn.__module__ - class_name = module_name.split('.')[-1] - method_name = fn.__name__ - return class_name, method_name diff --git a/tensorflow_federated/python/common_libs/tracing_test.py b/tensorflow_federated/python/common_libs/tracing_test.py deleted file mode 100644 index e6294b3c38..0000000000 --- a/tensorflow_federated/python/common_libs/tracing_test.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import functools -import io -import logging as std_logging -import threading -import time - -from absl import logging -from absl.testing import absltest - -from tensorflow_federated.python.common_libs import tracing - -# Traces may not run in _exactly_ one second, but we can assert it was at least -# one second; and most importantly the time should be logged. -ELAPSED_ONE_REGEX = r'Elapsed time [1-9][0-9]*\.[0-9]+' - - -class DebugLoggingTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.log = io.StringIO() - self.handler = std_logging.StreamHandler(self.log) - std_logging.root.addHandler(self.handler) - - def tearDown(self): - std_logging.root.removeHandler(self.handler) - self.handler.close() - super().tearDown() - - def _test_debug_logging_with_async_function( - self, async_fn, test_regex, *args, **kwargs - ): - try: - logging.set_verbosity(1) - retval = asyncio.run(async_fn(*args, **kwargs)) - finally: - logging.set_verbosity(0) - self.assertRegexMatch(''.join(self.log.getvalue()), [test_regex]) - self.log.truncate(0) - asyncio.run(async_fn(*args, **kwargs)) - self.assertEmpty(''.join(self.log.getvalue())) - return retval - - def _test_debug_logging_with_sync_function( - self, sync_fn, test_regex, *args, **kwargs - ): - try: - logging.set_verbosity(1) - retval = sync_fn(*args, **kwargs) - finally: - logging.set_verbosity(0) - self.assertRegexMatch(''.join(self.log.getvalue()), [test_regex]) - self.log.truncate(0) - self.assertEmpty(''.join(self.log.getvalue())) - return retval - - def test_logging_enter_exit(self): - @tracing.trace - async def foo(): - return await asyncio.sleep(1) - - self._test_debug_logging_with_async_function( - foo, '.*Entering .*foo.*\n.*Exiting .*foo.*' - ) - - def test_logging_timing_captured(self): - @tracing.trace - async def foo(): - return await asyncio.sleep(1) - - self._test_debug_logging_with_async_function(foo, 'Elapsed time') - - def test_logging_timing_captures_value_around_async_call(self): - @tracing.trace - async def foo(): - return await asyncio.sleep(1) - - self._test_debug_logging_with_async_function( - foo, r'\.foo\. ' + ELAPSED_ONE_REGEX - ) - - def test_logging_non_blocking_function(self): - @tracing.trace(span=True) - async def foo(): - return await asyncio.gather( - asyncio.sleep(1), asyncio.sleep(1), asyncio.sleep(1) - ) - - self._test_debug_logging_with_async_function( - foo, r'\.foo\. ' + ELAPSED_ONE_REGEX - ) - - def test_logging_non_blocking_method(self): - class AClass(absltest.TestCase): - - @tracing.trace(span=True) - async def async_method(self, foo_arg, bar_arg, arg3=None, arg4=None): - self.assertEqual('foo', foo_arg) - self.assertEqual('bar', bar_arg) - self.assertIsNotNone(arg3) - self.assertIsNotNone(arg4) - await asyncio.sleep(1) - return 3 - - a_class = AClass() - - result = self._test_debug_logging_with_async_function( - a_class.async_method, - # Non-blocking may not run exactly one second, but we can assert it was - # at least one second; and most importantly it should be logged. - r'AClass\.async_method\. ' + ELAPSED_ONE_REGEX, - 'foo', - 'bar', - arg3='baz', - arg4=True, - ) - self.assertEqual(3, result) - - def test_logging_blocking_method(self): - class AClass(absltest.TestCase): - - @tracing.trace(span=True) - def sync_method(self, foo_arg, bar_arg, arg3=None, arg4=None): - self.assertEqual('foo', foo_arg) - self.assertEqual('bar', bar_arg) - self.assertIsNotNone(arg3) - self.assertIsNotNone(arg4) - # Sleep for 1s is used to test that we measured runtime correctly - time.sleep(1) - return 3 - - a_class = AClass() - - result = self._test_debug_logging_with_sync_function( - a_class.sync_method, - r'AClass\.sync_method\. ' + ELAPSED_ONE_REGEX, - 'foo', - 'bar', - arg3='baz', - arg4=True, - ) - self.assertEqual(3, result) - - def test_logging_blocking_function(self): - @tracing.trace(span=True) - def foo(foo_arg, bar_arg, arg3=None, arg4=None): - self.assertEqual('foo', foo_arg) - self.assertEqual('bar', bar_arg) - self.assertIsNotNone(arg3) - self.assertIsNotNone(arg4) - # Sleep for 1s is used to test that we measured runtime correctly - time.sleep(1) - return 3 - - result = self._test_debug_logging_with_sync_function( - foo, - r'\.foo\. ' + ELAPSED_ONE_REGEX, - 'foo', - 'bar', - arg3='baz', - arg4=True, - ) - self.assertEqual(3, result) - - -class MockTracingProvider(tracing.TracingProvider): - - def __init__(self): - self.scopes = [] - self.sub_scopes = [] - self.nonces = [] - self.parent_span_yields = [] - self.fn_argss = [] - self.fn_kwargss = [] - self.trace_optss = [] - self.trace_results = [] - - def span( - self, - scope, - sub_scope, - nonce, - parent_span_yield, - fn_args, - fn_kwargs, - trace_opts, - ): - self.scopes.append(scope) - self.sub_scopes.append(sub_scope) - self.nonces.append(nonce) - self.parent_span_yields.append(parent_span_yield) - self.fn_argss.append(fn_args) - self.fn_kwargss.append(fn_kwargs) - self.trace_optss.append(trace_opts) - if parent_span_yield is None: - new_yield = 0 - else: - new_yield = parent_span_yield + 1 - result = yield new_yield - self.trace_results.append(result) - - -def set_mock_trace() -> MockTracingProvider: - mock = MockTracingProvider() - tracing.set_tracing_providers([mock]) - return mock - - -class TracingProviderInterfaceTest(absltest.TestCase): - - def test_basic_span(self): - mock = set_mock_trace() - with tracing.span('scope', 'sub_scope', options='some_option'): - pass - self.assertEqual(mock.scopes[0], 'scope') - self.assertEqual(mock.sub_scopes[0], 'sub_scope') - self.assertIsNone(mock.parent_span_yields[0]) - self.assertIsNone(mock.fn_argss[0]) - self.assertIsNone(mock.fn_kwargss[0]) - self.assertEqual(mock.trace_optss[0], {'options': 'some_option'}) - self.assertIsInstance(mock.trace_results[0], tracing.TracedSpan) - - def test_sibling_spans(self): - mock = set_mock_trace() - with tracing.span('parent', ''): - with tracing.span('child1', ''): - pass - with tracing.span('child2', ''): - pass - with tracing.span('parentless', ''): - pass - - self.assertEqual(mock.scopes, ['parent', 'child1', 'child2', 'parentless']) - self.assertEqual(mock.parent_span_yields, [None, 0, 0, None]) - - def test_nested_non_async_span(self): - mock = set_mock_trace() - with tracing.span('outer', 'osub'): - with tracing.span('middle', 'msub'): - with tracing.span('inner', 'isub'): - pass - self.assertEqual(mock.scopes, ['outer', 'middle', 'inner']) - self.assertEqual(mock.sub_scopes, ['osub', 'msub', 'isub']) - self.assertEqual(mock.parent_span_yields, [None, 0, 1]) - - def test_basic_trace(self): - mock = set_mock_trace() - - class MyClass: - - @tracing.trace(options='some_option') - def my_func(self, a, b, kw=None): - del a, b, kw - return 5 - - obj = MyClass() - obj.my_func(1, 2, kw=3) - self.assertEqual(mock.scopes[0], 'MyClass') - self.assertEqual(mock.sub_scopes[0], 'my_func') - self.assertIsNone(mock.parent_span_yields[0]) - self.assertEqual(mock.fn_argss[0], (obj, 1, 2)) - self.assertEqual(mock.fn_kwargss[0], {'kw': 3}) - self.assertEqual(mock.trace_optss[0], {'options': 'some_option'}) - self.assertIsInstance(mock.trace_results[0], tracing.TracedFunctionReturned) - self.assertEqual(mock.trace_results[0].value, 5) - - def test_trace_throws(self): - mock = set_mock_trace() - - class MyClass: - - @tracing.trace - def my_func(self): - raise ValueError(5) - - obj = MyClass() - with self.assertRaises(ValueError): - obj.my_func() - - self.assertIsInstance(mock.trace_results[0], tracing.TracedFunctionThrew) - self.assertEqual(mock.trace_results[0].error_type, ValueError) - self.assertIsInstance(mock.trace_results[0].error_value, ValueError) - - def test_parenting_non_async_to_async_to_nested_async(self): - mock = set_mock_trace() - loop = asyncio.new_event_loop() - loop.set_task_factory(tracing.propagate_trace_context_task_factory) - - def run_loop(): - loop.run_forever() - loop.close() - - thread = threading.Thread(target=functools.partial(run_loop), daemon=True) - thread.start() - - @tracing.trace - async def middle(): - with tracing.span('inner', ''): - pass - - with tracing.span('outer', ''): - # This sends the coroutine over to another thread, - # keeping the current trace context. - coro_with_trace_ctx = tracing.wrap_coroutine_in_current_trace_context( - middle() - ) - asyncio.run_coroutine_threadsafe(coro_with_trace_ctx, loop).result() - - loop.call_soon_threadsafe(loop.stop) - thread.join() - - self.assertEqual(mock.parent_span_yields, [None, 0, 1]) - self.assertEqual(mock.scopes, ['outer', '', 'inner']) - self.assertEqual(mock.sub_scopes, ['', 'middle', '']) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/backends/mapreduce/BUILD b/tensorflow_federated/python/core/backends/mapreduce/BUILD index 515223e751..bd2a47a2ac 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/BUILD +++ b/tensorflow_federated/python/core/backends/mapreduce/BUILD @@ -42,16 +42,9 @@ py_library( "//tensorflow_federated/python/core/environments/tensorflow_backend:compiled_computation_transformations", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", "//tensorflow_federated/python/core/impl/compiler:transformations", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", "//tensorflow_federated/python/core/impl/compiler:tree_transformations", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", + "@federated_language//federated_language", ], ) @@ -62,24 +55,11 @@ py_test( ":compiler", ":form_utils", ":mapreduce_test_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_computation_factory", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", "//tensorflow_federated/python/core/impl/compiler:building_block_test_utils", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:set_default_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:executor_factory", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -89,14 +69,7 @@ py_library( deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:typed_object", + "@federated_language//federated_language", ], ) @@ -108,10 +81,7 @@ py_test( ":forms", ":mapreduce_test_utils", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -126,20 +96,9 @@ py_library( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_building_block_factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", "//tensorflow_federated/python/core/impl/compiler:transformations", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", "//tensorflow_federated/python/core/impl/compiler:tree_transformations", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -156,18 +115,9 @@ py_test( "//tensorflow_federated/python/core/backends/test:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_backend:serialization_utils", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", "//tensorflow_federated/python/core/impl/compiler:tree_transformations", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:iterative_process", + "@federated_language//federated_language", ], ) @@ -178,15 +128,8 @@ py_library( deps = [ ":forms", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", "//tensorflow_federated/python/core/impl/compiler:tree_transformations", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/federated_context:value_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -197,18 +140,7 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_computation_factory", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/context_stack:symbol_binding_context", - "//tensorflow_federated/python/core/impl/federated_context:value_impl", - "//tensorflow_federated/python/core/impl/federated_context:value_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_conversions", + "@federated_language//federated_language", ], ) @@ -225,15 +157,7 @@ py_test( deps = [ ":intrinsics", "//tensorflow_federated/python/common_libs:golden", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_test_utils", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation_context", - "//tensorflow_federated/python/core/impl/federated_context:value_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -244,15 +168,8 @@ py_library( deps = [ ":forms", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", "//tensorflow_federated/python/core/impl/compiler:tree_transformations", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/backends/mapreduce/compiler.py b/tensorflow_federated/python/core/backends/mapreduce/compiler.py index b4455348e2..9d5f1d03da 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/compiler.py +++ b/tensorflow_federated/python/core/backends/mapreduce/compiler.py @@ -63,6 +63,7 @@ """ from absl import logging +import federated_language import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck @@ -70,16 +71,8 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import compiled_computation_transformations from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import transformation_utils from tensorflow_federated.python.core.impl.compiler import transformations -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis class MapReduceFormCompilationError(Exception): @@ -89,34 +82,41 @@ class MapReduceFormCompilationError(Exception): def check_extraction_result(before_extraction, extracted): """Checks parsing TFF to TF has constructed an object of correct type.""" py_typecheck.check_type( - before_extraction, building_blocks.ComputationBuildingBlock + before_extraction, federated_language.framework.ComputationBuildingBlock + ) + py_typecheck.check_type( + extracted, federated_language.framework.ComputationBuildingBlock ) - py_typecheck.check_type(extracted, building_blocks.ComputationBuildingBlock) if isinstance( - before_extraction.type_signature, computation_types.FunctionType + before_extraction.type_signature, federated_language.FunctionType ): - if not isinstance(extracted, building_blocks.CompiledComputation): + if not isinstance( + extracted, federated_language.framework.CompiledComputation + ): raise MapReduceFormCompilationError( - 'We expect to parse down to a `building_blocks.CompiledComputation`, ' - 'since we have the functional type {} after unwrapping placement. ' - 'Instead we have the computation {} of type {}'.format( + 'We expect to parse down to a' + ' `federated_language.framework.CompiledComputation`, since we have' + ' the functional type {} after unwrapping placement. Instead we have' + ' the computation {} of type {}'.format( before_extraction.type_signature, extracted, extracted.type_signature, ) ) else: - if not isinstance(extracted, building_blocks.Call): + if not isinstance(extracted, federated_language.framework.Call): raise MapReduceFormCompilationError( - 'We expect to parse down to a `building_blocks.Call`, since we have ' - 'the non-functional type {} after unwrapping placement. Instead we ' - 'have the computation {} of type {}'.format( + 'We expect to parse down to a `federated_language.framework.Call`,' + ' since we have the non-functional type {} after unwrapping' + ' placement. Instead we have the computation {} of type {}'.format( before_extraction.type_signature, extracted, extracted.type_signature, ) ) - if not isinstance(extracted.function, building_blocks.CompiledComputation): + if not isinstance( + extracted.function, federated_language.framework.CompiledComputation + ): raise MapReduceFormCompilationError( 'We expect to parse a computation of the non-functional type {} down ' 'to a called TensorFlow block. Instead we hav a call to the ' @@ -185,8 +185,9 @@ def consolidate_and_extract_local_processing(comp, grappler_config_proto): this helper method. 4. If `comp` is of a functional type, it is either an instance of - `building_blocks.CompiledComputation`, in which case there is nothing for - us to do here, or a `building_blocks.Lambda`. + `federated_language.framework.CompiledComputation`, in which case there is + nothing for + us to do here, or a `federated_language.framework.Lambda`. 5. There is at most one unbound reference under `comp`, and this is only allowed in the case that `comp` is not of a functional type. @@ -236,8 +237,8 @@ def consolidate_and_extract_local_processing(comp, grappler_config_proto): `(T -> U)`, where `p` is again a specific placement. Args: - comp: An instance of `building_blocks.ComputationBuildingBlock` that serves - as the input to this transformation, as described above. + comp: An instance of `federated_language.framework.ComputationBuildingBlock` + that serves as the input to this transformation, as described above. grappler_config_proto: An instance of `tf.compat.v1.ConfigProto` to configure Grappler graph optimization of the generated TensorFlow graph. If `grappler_config_proto` has @@ -245,11 +246,14 @@ def consolidate_and_extract_local_processing(comp, grappler_config_proto): bypassed. Returns: - An instance of `building_blocks.CompiledComputation` that holds the + An instance of `federated_language.framework.CompiledComputation` that holds + the TensorFlow section produced by this extraction step, as described above. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - if not isinstance(comp.type_signature, computation_types.FunctionType): + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) + if not isinstance(comp.type_signature, federated_language.FunctionType): raise ValueError( f'Expected a `tff.FunctionType`, found {comp.type_signature}.' ) @@ -265,22 +269,26 @@ def consolidate_and_extract_local_processing(comp, grappler_config_proto): def unpack_compiled_computations( - comp: building_blocks.ComputationBuildingBlock, -) -> building_blocks.ComputationBuildingBlock: + comp: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.ComputationBuildingBlock: """Deserializes compiled computations into building blocks where possible.""" def _unpack(subcomp): - if not isinstance(subcomp, building_blocks.CompiledComputation): + if not isinstance( + subcomp, federated_language.framework.CompiledComputation + ): return subcomp, False kind = subcomp.proto.WhichOneof('computation') if kind == 'tensorflow' or kind == 'xla': return subcomp, False return ( - building_blocks.ComputationBuildingBlock.from_proto(subcomp.proto), + federated_language.framework.ComputationBuildingBlock.from_proto( + subcomp.proto + ), True, ) - comp, _ = transformation_utils.transform_postorder(comp, _unpack) + comp, _ = federated_language.framework.transform_postorder(comp, _unpack) return comp @@ -293,7 +301,7 @@ class ExternalBlockToTensorFlowError(ValueError): def _evaluate_to_tensorflow( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, bindings: dict[str, object], ) -> object: """Evaluates `comp` within a TensorFlow context, returning a tensor structure. @@ -319,21 +327,23 @@ def _evaluate_to_tensorflow( ValueError: If `comp` contains `CompiledCompilations` other than TensorFlow or XLA. """ - if isinstance(comp, building_blocks.Block): + if isinstance(comp, federated_language.framework.Block): for name, value in comp.locals: bindings[name] = _evaluate_to_tensorflow(value, bindings) return _evaluate_to_tensorflow(comp.result, bindings) - if isinstance(comp, building_blocks.CompiledComputation): + if isinstance(comp, federated_language.framework.CompiledComputation): kind = comp.proto.WhichOneof('computation') if kind == 'tensorflow': def call_concrete(*args): - concrete = computation_impl.ConcreteComputation( + concrete = federated_language.framework.ConcreteComputation( computation_proto=comp.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) result = concrete(*args) - if isinstance(comp.type_signature.result, computation_types.StructType): + if isinstance( + comp.type_signature.result, federated_language.StructType + ): return structure.from_container(result, recursive=True) return result @@ -343,13 +353,13 @@ def call_concrete(*args): f'Cannot compile XLA subcomputation to TensorFlow:\n{comp}' ) raise ValueError(f'Unexpected compiled computation kind:\n{kind}') - if isinstance(comp, building_blocks.Call): + if isinstance(comp, federated_language.framework.Call): function = _evaluate_to_tensorflow(comp.function, bindings) if comp.argument is None: return function() else: return function(_evaluate_to_tensorflow(comp.argument, bindings)) - if isinstance(comp, building_blocks.Lambda): + if isinstance(comp, federated_language.framework.Lambda): if comp.parameter_type is None: return lambda: _evaluate_to_tensorflow(comp.result, bindings) else: @@ -359,23 +369,23 @@ def lambda_function(arg): return _evaluate_to_tensorflow(comp.result, bindings) return lambda_function - if isinstance(comp, building_blocks.Reference): + if isinstance(comp, federated_language.framework.Reference): return bindings[comp.name] - if isinstance(comp, building_blocks.Selection): + if isinstance(comp, federated_language.framework.Selection): return _evaluate_to_tensorflow(comp.source, bindings)[comp.as_index()] - if isinstance(comp, building_blocks.Struct): + if isinstance(comp, federated_language.framework.Struct): elements = [] for name, element in structure.iter_elements(comp): elements.append((name, _evaluate_to_tensorflow(element, bindings))) return structure.Struct(elements) - if isinstance(comp, building_blocks.Literal): + if isinstance(comp, federated_language.framework.Literal): return tf.constant(comp.value) if isinstance( comp, ( - building_blocks.Intrinsic, - building_blocks.Data, - building_blocks.Placement, + federated_language.framework.Intrinsic, + federated_language.framework.Data, + federated_language.framework.Placement, ), ): raise ExternalBlockToTensorFlowError( @@ -385,37 +395,40 @@ def lambda_function(arg): def compile_local_computation_to_tensorflow( - comp: building_blocks.ComputationBuildingBlock, -) -> building_blocks.ComputationBuildingBlock: + comp: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.ComputationBuildingBlock: """Compiles a fully specified local computation to TensorFlow. Args: - comp: A `building_blocks.ComputationBuildingBlock` which can be compiled to - TensorFlow. In order to compile a computation to TensorFlow, it must not - contain 1. References to values defined outside of comp, 2. `Data`, - `Intrinsic`, or `Placement` blocks, or 3. Calls to intrinsics or + comp: A `federated_language.framework.ComputationBuildingBlock` which can be + compiled to TensorFlow. In order to compile a computation to TensorFlow, + it must not contain 1. References to values defined outside of comp, 2. + `Data`, `Intrinsic`, or `Placement` blocks, or 3. Calls to intrinsics or non-TensorFlow computations. Returns: - A `building_blocks.ComputationBuildingBlock` containing a TensorFlow-only + A `federated_language.framework.ComputationBuildingBlock` containing a + TensorFlow-only representation of `comp`. If `comp` is of functional type, this will be - a `building_blocks.CompiledComputation`. Otherwise, it will be a - `building_blocks.Call` which wraps a `building_blocks.CompiledComputation`. + a `federated_language.framework.CompiledComputation`. Otherwise, it will be + a + `federated_language.framework.Call` which wraps a + `federated_language.framework.CompiledComputation`. """ - if not isinstance(comp.type_signature, computation_types.FunctionType): - lambda_wrapped = building_blocks.Lambda(None, None, comp) - return building_blocks.Call( + if not isinstance(comp.type_signature, federated_language.FunctionType): + lambda_wrapped = federated_language.framework.Lambda(None, None, comp) + return federated_language.framework.Call( compile_local_computation_to_tensorflow(lambda_wrapped), None ) parameter_type = comp.type_signature.parameter # pytype: disable=attribute-error - type_analysis.check_tensorflow_compatible_type(parameter_type) - type_analysis.check_tensorflow_compatible_type( + federated_language.framework.check_tensorflow_compatible_type(parameter_type) + federated_language.framework.check_tensorflow_compatible_type( comp.type_signature.result # pytype: disable=attribute-error ) if ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'tensorflow' ): return comp @@ -426,22 +439,23 @@ def compile_local_computation_to_tensorflow( comp = transformations.to_call_dominant(comp) if parameter_type is None: - to_evaluate = building_blocks.Call(comp) + to_evaluate = federated_language.framework.Call(comp) @tensorflow_computation.tf_computation def result_computation(): return _evaluate_to_tensorflow(to_evaluate, {}) else: - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) parameter_name = next(name_generator) - to_evaluate = building_blocks.Call( - comp, building_blocks.Reference(parameter_name, parameter_type) + to_evaluate = federated_language.framework.Call( + comp, + federated_language.framework.Reference(parameter_name, parameter_type), ) @tensorflow_computation.tf_computation(parameter_type) def result_computation(arg): - if isinstance(parameter_type, computation_types.StructType): + if isinstance(parameter_type, federated_language.StructType): arg = structure.from_container(arg, recursive=True) return _evaluate_to_tensorflow(to_evaluate, {parameter_name: arg}) @@ -449,8 +463,8 @@ def result_computation(arg): def compile_local_subcomputations_to_tensorflow( - comp: building_blocks.ComputationBuildingBlock, -) -> building_blocks.ComputationBuildingBlock: + comp: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.ComputationBuildingBlock: """Compiles subcomputations to TensorFlow where possible.""" comp = unpack_compiled_computations(comp) local_cache = {} @@ -462,15 +476,17 @@ def _is_local(comp): if isinstance( comp, ( - building_blocks.Intrinsic, - building_blocks.Data, - building_blocks.Placement, + federated_language.framework.Intrinsic, + federated_language.framework.Data, + federated_language.framework.Placement, ), - ) or type_analysis.contains_federated_types(comp.type_signature): + ) or federated_language.framework.contains_federated_types( + comp.type_signature + ): local_cache[comp] = False return False if ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'xla' ): local_cache[comp] = False @@ -481,7 +497,9 @@ def _is_local(comp): return False return True - unbound_ref_map = transformation_utils.get_map_of_unbound_references(comp) + unbound_ref_map = federated_language.framework.get_map_of_unbound_references( + comp + ) def _compile_if_local(comp): if _is_local(comp) and not unbound_ref_map[comp]: @@ -492,7 +510,9 @@ def _compile_if_local(comp): # first transformed to TensorFlow if they have a parent local computation # which could have instead been transformed into a larger single block of # TensorFlow. - comp, _ = transformation_utils.transform_preorder(comp, _compile_if_local) + comp, _ = federated_language.framework.transform_preorder( + comp, _compile_if_local + ) return comp @@ -500,12 +520,13 @@ def parse_tff_to_tf(comp, grappler_config_proto): """Parses TFF construct `comp` into TensorFlow construct. Does not change the type signature of `comp`. Therefore may return either - a `building_blocks.CompiledComputation` or a `building_blocks.Call` with no - argument and function `building_blocks.CompiledComputation`. + a `federated_language.framework.CompiledComputation` or a + `federated_language.framework.Call` with no + argument and function `federated_language.framework.CompiledComputation`. Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` to parse down - to a single TF block. + comp: Instance of `federated_language.framework.ComputationBuildingBlock` to + parse down to a single TF block. grappler_config_proto: An instance of `tf.compat.v1.ConfigProto` to configure Grappler graph optimization of the generated TensorFlow graph. If `grappler_config_proto` has @@ -514,7 +535,8 @@ def parse_tff_to_tf(comp, grappler_config_proto): Returns: The result of parsing TFF to TF. If successful, this is either a single - `building_blocks.CompiledComputation`, or a call to one. If unsuccessful, + `federated_language.framework.CompiledComputation`, or a call to one. If + unsuccessful, there may be more TFF constructs still remaining. Notice it is not the job of this function, but rather its callers, to check that the result of this parse is as expected. @@ -549,24 +571,26 @@ def concatenate_function_outputs(first_function, second_function): these functions in parallel and concatenating the outputs in a tuple. Args: - first_function: Instance of `building_blocks.Lambda` whose result we wish to - concatenate with the result of `second_function`. - second_function: Instance of `building_blocks.Lambda` whose result we wish - to concatenate with the result of `first_function`. + first_function: Instance of `federated_language.framework.Lambda` whose + result we wish to concatenate with the result of `second_function`. + second_function: Instance of `federated_language.framework.Lambda` whose + result we wish to concatenate with the result of `first_function`. Returns: - A new instance of `building_blocks.Lambda` with unique names representing + A new instance of `federated_language.framework.Lambda` with unique names + representing the computation described above. Raises: - TypeError: If the arguments are not instances of `building_blocks.Lambda`, + TypeError: If the arguments are not instances of + `federated_language.framework.Lambda`, or declare parameters of different types. """ - py_typecheck.check_type(first_function, building_blocks.Lambda) - py_typecheck.check_type(second_function, building_blocks.Lambda) - tree_analysis.check_has_unique_names(first_function) - tree_analysis.check_has_unique_names(second_function) + py_typecheck.check_type(first_function, federated_language.framework.Lambda) + py_typecheck.check_type(second_function, federated_language.framework.Lambda) + federated_language.framework.check_has_unique_names(first_function) + federated_language.framework.check_has_unique_names(second_function) if first_function.parameter_type != second_function.parameter_type: raise TypeError( @@ -580,7 +604,7 @@ def concatenate_function_outputs(first_function, second_function): def _rename_first_function_arg(comp): if ( - isinstance(comp, building_blocks.Reference) + isinstance(comp, federated_language.framework.Reference) and comp.name == first_function.parameter_name ): if comp.type_signature != second_function.parameter_type: @@ -588,21 +612,23 @@ def _rename_first_function_arg(comp): '{}, {}'.format(comp.type_signature, second_function.parameter_type) ) return ( - building_blocks.Reference( + federated_language.framework.Reference( second_function.parameter_name, comp.type_signature ), True, ) return comp, False - first_function, _ = transformation_utils.transform_postorder( + first_function, _ = federated_language.framework.transform_postorder( first_function, _rename_first_function_arg ) - concatenated_function = building_blocks.Lambda( + concatenated_function = federated_language.framework.Lambda( second_function.parameter_name, second_function.parameter_type, - building_blocks.Struct([first_function.result, second_function.result]), + federated_language.framework.Struct( + [first_function.result, second_function.result] + ), ) renamed, _ = tree_transformations.uniquify_reference_names( diff --git a/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py b/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py index b8fd425a08..b3b1c34d0a 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py +++ b/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py @@ -13,38 +13,24 @@ # limitations under the License. from absl.testing import absltest +import federated_language +from federated_language.proto import computation_pb2 import numpy as np import tensorflow as tf - -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.core.backends.mapreduce import compiler from tensorflow_federated.python.core.backends.mapreduce import form_utils from tensorflow_federated.python.core.backends.mapreduce import mapreduce_test_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.compiler import building_block_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import set_default_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executor_stacks import executor_factory # pylint: enable=line-too-long -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import type_test_utils DEFAULT_GRAPPLER_CONFIG = tf.compat.v1.ConfigProto() def _create_test_context(): factory = executor_factory.local_cpp_executor_factory() - return sync_execution_context.SyncExecutionContext( + return federated_language.framework.SyncExecutionContext( executor_fn=factory, transform_args=tensorflow_computation.transform_args, transform_result=tensorflow_computation.transform_result, @@ -67,18 +53,20 @@ def get_function_from_first_symbol_binding_in_lambda_result(self, tree): Returns: Inner function value described above. """ - self.assertIsInstance(tree, building_blocks.Lambda) + self.assertIsInstance(tree, federated_language.framework.Lambda) self.assertIsNone(tree.parameter_type) - self.assertIsInstance(tree.result, building_blocks.Block) + self.assertIsInstance(tree.result, federated_language.framework.Block) comp_to_return = tree.result.locals[0][1] - self.assertIsInstance(comp_to_return, building_blocks.Call) + self.assertIsInstance(comp_to_return, federated_language.framework.Call) return comp_to_return.function def compiled_computation_for_initialize(self, initialize): # Create a federated version of initialize. - @federated_computation.federated_computation + @federated_language.federated_computation def federated_initialize_computation(): - return intrinsics.federated_value(initialize(), placements.SERVER) + return federated_language.federated_value( + initialize(), federated_language.SERVER + ) block = federated_initialize_computation.to_building_block() return self.get_function_from_first_symbol_binding_in_lambda_result(block) @@ -86,19 +74,19 @@ def federated_initialize_computation(): def test_raises_on_none_args(self): with self.assertRaisesRegex(TypeError, 'None'): compiler.check_extraction_result( - None, building_blocks.Reference('x', np.int32) + None, federated_language.framework.Reference('x', np.int32) ) with self.assertRaisesRegex(TypeError, 'None'): compiler.check_extraction_result( - building_blocks.Reference('x', np.int32), None + federated_language.framework.Reference('x', np.int32), None ) def test_raises_function_and_call(self): - function = building_blocks.Reference( - 'f', computation_types.FunctionType(np.int32, np.int32) + function = federated_language.framework.Reference( + 'f', federated_language.FunctionType(np.int32, np.int32) ) - integer_ref = building_blocks.Reference('x', np.int32) - call = building_blocks.Call(function, integer_ref) + integer_ref = federated_language.framework.Reference('x', np.int32) + call = federated_language.framework.Call(function, integer_ref) with self.assertRaisesRegex( compiler.MapReduceFormCompilationError, 'we have the functional type' ): @@ -110,7 +98,7 @@ def test_raises_non_function_and_compiled_computation(self): ) init = form_utils.get_state_initialization_computation(initialize) compiled_computation = self.compiled_computation_for_initialize(init) - integer_ref = building_blocks.Reference('x', np.int32) + integer_ref = federated_language.framework.Reference('x', np.int32) with self.assertRaisesRegex( compiler.MapReduceFormCompilationError, 'we have the non-functional type', @@ -123,8 +111,8 @@ def test_raises_function_and_compiled_computation_of_different_type(self): ) init = form_utils.get_state_initialization_computation(initialize) compiled_computation = self.compiled_computation_for_initialize(init) - function = building_blocks.Reference( - 'f', computation_types.FunctionType(np.int32, np.int32) + function = federated_language.framework.Reference( + 'f', federated_language.FunctionType(np.int32, np.int32) ) with self.assertRaisesRegex( compiler.MapReduceFormCompilationError, 'incorrect TFF type' @@ -132,11 +120,11 @@ def test_raises_function_and_compiled_computation_of_different_type(self): compiler.check_extraction_result(function, compiled_computation) def test_raises_tensor_and_call_to_not_compiled_computation(self): - function = building_blocks.Reference( - 'f', computation_types.FunctionType(np.int32, np.int32) + function = federated_language.framework.Reference( + 'f', federated_language.FunctionType(np.int32, np.int32) ) - ref_to_int = building_blocks.Reference('x', np.int32) - called_fn = building_blocks.Call(function, ref_to_int) + ref_to_int = federated_language.framework.Reference('x', np.int32) + called_fn = federated_language.framework.Call(function, ref_to_int) with self.assertRaisesRegex( compiler.MapReduceFormCompilationError, 'missing' ): @@ -148,7 +136,7 @@ def test_passes_function_and_compiled_computation_of_same_type(self): ) init = form_utils.get_state_initialization_computation(initialize) compiled_computation = self.compiled_computation_for_initialize(init) - function = building_blocks.Reference( + function = federated_language.framework.Reference( 'f', compiled_computation.type_signature ) compiler.check_extraction_result(function, compiled_computation) @@ -174,37 +162,43 @@ def test_already_reduced_case(self): comp, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(result, building_blocks.CompiledComputation) + self.assertIsInstance( + result, federated_language.framework.CompiledComputation + ) self.assertIsInstance(result.proto, computation_pb2.Computation) self.assertEqual(result.proto.WhichOneof('computation'), 'tensorflow') def test_reduces_unplaced_lambda_leaving_type_signature_alone(self): - lam = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + lam = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) ) extracted_tf = compiler.consolidate_and_extract_local_processing( lam, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation + ) self.assertEqual(extracted_tf.type_signature, lam.type_signature) def test_further_concretizes_type_if_possible(self): - unk_size_int_type = computation_types.TensorType( + unk_size_int_type = federated_language.TensorType( dtype=np.int32, shape=[None] ) - known_size_int_type = computation_types.TensorType( + known_size_int_type = federated_language.TensorType( dtype=np.int32, shape=[1] ) - lam_with_unk_size = building_blocks.Lambda( + lam_with_unk_size = federated_language.framework.Lambda( 'x', unk_size_int_type, - building_blocks.Reference('x', unk_size_int_type), + federated_language.framework.Reference('x', unk_size_int_type), ) - known_size_ref = building_blocks.Reference('y', known_size_int_type) - called_identity_knowable_size = building_blocks.Call( + known_size_ref = federated_language.framework.Reference( + 'y', known_size_int_type + ) + called_identity_knowable_size = federated_language.framework.Call( lam_with_unk_size, known_size_ref ) - lam_with_knowable_size = building_blocks.Lambda( + lam_with_knowable_size = federated_language.framework.Lambda( known_size_ref.name, known_size_ref.type_signature, called_identity_knowable_size, @@ -214,10 +208,12 @@ def test_further_concretizes_type_if_possible(self): lam_with_knowable_size, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation + ) # Assert assignability only goes one way in this case--the compiler can # concretize the type of the lambda further. - type_test_utils.assert_type_assignable_from( + federated_language.framework.assert_type_assignable_from( lam_with_knowable_size.type_signature, extracted_tf.type_signature ) self.assertFalse( @@ -227,94 +223,128 @@ def test_further_concretizes_type_if_possible(self): ) def test_reduces_unplaced_lambda_to_equivalent_tf(self): - lam = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + lam = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) ) extracted_tf = compiler.consolidate_and_extract_local_processing( lam, DEFAULT_GRAPPLER_CONFIG ) - executable_tf = computation_impl.ConcreteComputation.from_building_block( - extracted_tf + executable_tf = ( + federated_language.framework.ConcreteComputation.from_building_block( + extracted_tf + ) ) - executable_lam = computation_impl.ConcreteComputation.from_building_block( - lam + executable_lam = ( + federated_language.framework.ConcreteComputation.from_building_block( + lam + ) ) for k in range(10): self.assertEqual(executable_tf(k), executable_lam(k)) def test_reduces_federated_identity_to_member_identity(self): - fed_int_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - lam = building_blocks.Lambda( - 'x', fed_int_type, building_blocks.Reference('x', fed_int_type) + fed_int_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + lam = federated_language.framework.Lambda( + 'x', + fed_int_type, + federated_language.framework.Reference('x', fed_int_type), ) extracted_tf = compiler.consolidate_and_extract_local_processing( lam, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) - unplaced_function_type = computation_types.FunctionType( + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation + ) + unplaced_function_type = federated_language.FunctionType( fed_int_type.member, fed_int_type.member ) self.assertEqual(extracted_tf.type_signature, unplaced_function_type) def test_reduces_federated_map_to_equivalent_function(self): - lam = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + lam = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) ) - arg_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - arg = building_blocks.Reference('arg', arg_type) - map_block = building_block_factory.create_federated_map_or_apply(lam, arg) - mapping_fn = building_blocks.Lambda('arg', arg_type, map_block) + arg_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + arg = federated_language.framework.Reference('arg', arg_type) + map_block = federated_language.framework.create_federated_map_or_apply( + lam, arg + ) + mapping_fn = federated_language.framework.Lambda('arg', arg_type, map_block) extracted_tf = compiler.consolidate_and_extract_local_processing( mapping_fn, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) - executable_tf = computation_impl.ConcreteComputation.from_building_block( - extracted_tf + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation ) - executable_lam = computation_impl.ConcreteComputation.from_building_block( - lam + executable_tf = ( + federated_language.framework.ConcreteComputation.from_building_block( + extracted_tf + ) + ) + executable_lam = ( + federated_language.framework.ConcreteComputation.from_building_block( + lam + ) ) for k in range(10): self.assertEqual(executable_tf(k), executable_lam(k)) def test_reduces_federated_apply_to_equivalent_function(self): - lam = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + lam = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) + ) + arg_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - arg_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - arg = building_blocks.Reference('arg', arg_type) - map_block = building_block_factory.create_federated_map_or_apply(lam, arg) - mapping_fn = building_blocks.Lambda('arg', arg_type, map_block) + arg = federated_language.framework.Reference('arg', arg_type) + map_block = federated_language.framework.create_federated_map_or_apply( + lam, arg + ) + mapping_fn = federated_language.framework.Lambda('arg', arg_type, map_block) extracted_tf = compiler.consolidate_and_extract_local_processing( mapping_fn, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) - executable_tf = computation_impl.ConcreteComputation.from_building_block( - extracted_tf + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation ) - executable_lam = computation_impl.ConcreteComputation.from_building_block( - lam + executable_tf = ( + federated_language.framework.ConcreteComputation.from_building_block( + extracted_tf + ) + ) + executable_lam = ( + federated_language.framework.ConcreteComputation.from_building_block( + lam + ) ) for k in range(10): self.assertEqual(executable_tf(k), executable_lam(k)) def test_reduces_federated_value_at_server_to_equivalent_noarg_function(self): zero_proto, zero_type = tensorflow_computation_factory.create_constant( - 0, computation_types.TensorType(np.int32) + 0, federated_language.TensorType(np.int32) ) - zero_compiled = building_blocks.CompiledComputation( + zero_compiled = federated_language.framework.CompiledComputation( zero_proto, type_signature=zero_type ) - zero = building_blocks.Call(zero_compiled, None) - federated_value = building_block_factory.create_federated_value( - zero, placements.SERVER + zero = federated_language.framework.Call(zero_compiled, None) + federated_value = federated_language.framework.create_federated_value( + zero, federated_language.SERVER + ) + federated_value_func = federated_language.framework.Lambda( + None, None, federated_value ) - federated_value_func = building_blocks.Lambda(None, None, federated_value) extracted_tf = compiler.consolidate_and_extract_local_processing( federated_value_func, DEFAULT_GRAPPLER_CONFIG ) - executable_tf = computation_impl.ConcreteComputation.from_building_block( - extracted_tf + executable_tf = ( + federated_language.framework.ConcreteComputation.from_building_block( + extracted_tf + ) ) self.assertEqual(executable_tf(), 0) @@ -322,94 +352,117 @@ def test_reduces_federated_value_at_clients_to_equivalent_noarg_function( self, ): zero_proto, zero_type = tensorflow_computation_factory.create_constant( - 0, computation_types.TensorType(np.int32) + 0, federated_language.TensorType(np.int32) ) - zero_compiled = building_blocks.CompiledComputation( + zero_compiled = federated_language.framework.CompiledComputation( zero_proto, type_signature=zero_type ) - zero = building_blocks.Call(zero_compiled, None) - federated_value = building_block_factory.create_federated_value( - zero, placements.CLIENTS + zero = federated_language.framework.Call(zero_compiled, None) + federated_value = federated_language.framework.create_federated_value( + zero, federated_language.CLIENTS + ) + federated_value_func = federated_language.framework.Lambda( + None, None, federated_value ) - federated_value_func = building_blocks.Lambda(None, None, federated_value) extracted_tf = compiler.consolidate_and_extract_local_processing( federated_value_func, DEFAULT_GRAPPLER_CONFIG ) - executable_tf = computation_impl.ConcreteComputation.from_building_block( - extracted_tf + executable_tf = ( + federated_language.framework.ConcreteComputation.from_building_block( + extracted_tf + ) ) self.assertEqual(executable_tf(), 0) def test_reduces_generic_intrinsic_to_equivalent_tf_op(self): - arg_type = computation_types.FederatedType(np.int32, placements.SERVER) - arg = building_blocks.Reference('arg', arg_type) - multiply_intrinsic = building_blocks.Intrinsic( - intrinsic_defs.GENERIC_MULTIPLY.uri, - computation_types.FunctionType([arg_type, arg_type], arg_type), + arg_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - multiply_fn = building_blocks.Lambda( + arg = federated_language.framework.Reference('arg', arg_type) + multiply_intrinsic = federated_language.framework.Intrinsic( + federated_language.framework.GENERIC_MULTIPLY.uri, + federated_language.FunctionType([arg_type, arg_type], arg_type), + ) + multiply_fn = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Call( - multiply_intrinsic, building_blocks.Struct([arg, arg]) + federated_language.framework.Call( + multiply_intrinsic, federated_language.framework.Struct([arg, arg]) ), ) extracted_tf = compiler.consolidate_and_extract_local_processing( multiply_fn, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) - executable_tf = computation_impl.ConcreteComputation.from_building_block( - extracted_tf + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation + ) + executable_tf = ( + federated_language.framework.ConcreteComputation.from_building_block( + extracted_tf + ) ) for k in range(10): self.assertEqual(executable_tf(k), k * k) def test_reduces_lambda_returning_empty_tuple_to_tf(self): - empty_tuple = building_blocks.Struct([]) - lam = building_blocks.Lambda('x', np.int32, empty_tuple) + empty_tuple = federated_language.framework.Struct([]) + lam = federated_language.framework.Lambda('x', np.int32, empty_tuple) extracted_tf = compiler.consolidate_and_extract_local_processing( lam, DEFAULT_GRAPPLER_CONFIG ) - self.assertIsInstance(extracted_tf, building_blocks.CompiledComputation) + self.assertIsInstance( + extracted_tf, federated_language.framework.CompiledComputation + ) class CompileLocalComputationToTensorFlow(absltest.TestCase): def assert_compiles_to_tensorflow( - self, comp: building_blocks.ComputationBuildingBlock + self, comp: federated_language.framework.ComputationBuildingBlock ): result = compiler.compile_local_computation_to_tensorflow(comp) - if isinstance(comp.type_signature, computation_types.FunctionType): - if not isinstance(result, building_blocks.CompiledComputation): + if isinstance(comp.type_signature, federated_language.FunctionType): + if not isinstance( + result, federated_language.framework.CompiledComputation + ): raise ValueError( - 'Expected a `building_blocks.CompiledComputation`, found' - f' {type(result)}.' + 'Expected a `federated_language.framework.CompiledComputation`,' + f' found {type(result)}.' ) else: - if not isinstance(result, building_blocks.Call): + if not isinstance(result, federated_language.framework.Call): raise ValueError( - f'Expected a `building_blocks.Call`, found {type(result)}.' + 'Expected a `federated_language.framework.Call`, found' + f' {type(result)}.' ) - if not isinstance(result.function, building_blocks.CompiledComputation): + if not isinstance( + result.function, federated_language.framework.CompiledComputation + ): raise ValueError( - 'Expected a `building_blocks.CompiledComputation`, found' - f' {type(result.function)}.' + 'Expected a `federated_language.framework.CompiledComputation`,' + f' found {type(result.function)}.' ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( comp.type_signature, result.type_signature ) def test_returns_tf_computation_with_functional_type_lambda_no_block(self): - param = building_blocks.Reference('x', [('a', np.int32), ('b', np.float32)]) - sel = building_blocks.Selection(source=param, index=0) - tup = building_blocks.Struct([sel, sel, sel]) - lam = building_blocks.Lambda(param.name, param.type_signature, tup) + param = federated_language.framework.Reference( + 'x', [('a', np.int32), ('b', np.float32)] + ) + sel = federated_language.framework.Selection(source=param, index=0) + tup = federated_language.framework.Struct([sel, sel, sel]) + lam = federated_language.framework.Lambda( + param.name, param.type_signature, tup + ) self.assert_compiles_to_tensorflow(lam) def test_returns_tf_computation_with_functional_type_lambda_with_block(self): - param = building_blocks.Reference('x', [('a', np.int32), ('b', np.float32)]) - block_to_param = building_blocks.Block([('x', param)], param) - lam = building_blocks.Lambda( + param = federated_language.framework.Reference( + 'x', [('a', np.int32), ('b', np.float32)] + ) + block_to_param = federated_language.framework.Block([('x', param)], param) + lam = federated_language.framework.Lambda( param.name, param.type_signature, block_to_param ) self.assert_compiles_to_tensorflow(lam) @@ -417,100 +470,104 @@ def test_returns_tf_computation_with_functional_type_lambda_with_block(self): def test_returns_tf_computation_with_functional_type_block_to_lambda_no_block( self, ): - concrete_int_type = computation_types.TensorType(np.int32) - param = building_blocks.Reference('x', np.float32) - lam = building_blocks.Lambda(param.name, param.type_signature, param) + concrete_int_type = federated_language.TensorType(np.int32) + param = federated_language.framework.Reference('x', np.float32) + lam = federated_language.framework.Lambda( + param.name, param.type_signature, param + ) unused_proto, unused_type = tensorflow_computation_factory.create_constant( 1, concrete_int_type ) - unused_compiled = building_blocks.CompiledComputation( + unused_compiled = federated_language.framework.CompiledComputation( unused_proto, type_signature=unused_type ) - unused_int = building_blocks.Call(unused_compiled, None) - blk_to_lam = building_blocks.Block([('y', unused_int)], lam) + unused_int = federated_language.framework.Call(unused_compiled, None) + blk_to_lam = federated_language.framework.Block([('y', unused_int)], lam) self.assert_compiles_to_tensorflow(blk_to_lam) def test_returns_tf_computation_with_functional_type_block_to_lambda_with_block( self, ): - concrete_int_type = computation_types.TensorType(np.int32) - param = building_blocks.Reference('x', np.float32) - block_to_param = building_blocks.Block([('x', param)], param) - lam = building_blocks.Lambda( + concrete_int_type = federated_language.TensorType(np.int32) + param = federated_language.framework.Reference('x', np.float32) + block_to_param = federated_language.framework.Block([('x', param)], param) + lam = federated_language.framework.Lambda( param.name, param.type_signature, block_to_param ) unused_proto, unused_type = tensorflow_computation_factory.create_constant( 1, concrete_int_type ) - unused_compiled = building_blocks.CompiledComputation( + unused_compiled = federated_language.framework.CompiledComputation( unused_proto, type_signature=unused_type ) - unused_int = building_blocks.Call(unused_compiled, None) - blk_to_lam = building_blocks.Block([('y', unused_int)], lam) + unused_int = federated_language.framework.Call(unused_compiled, None) + blk_to_lam = federated_language.framework.Block([('y', unused_int)], lam) self.assert_compiles_to_tensorflow(blk_to_lam) def test_returns_tf_computation_block_with_compiled_comp(self): - concrete_int_type = computation_types.TensorType(np.int32) + concrete_int_type = federated_language.TensorType(np.int32) proto, type_signature = tensorflow_computation_factory.create_identity( concrete_int_type ) - tf_identity = building_blocks.CompiledComputation( + tf_identity = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) unused_proto, unused_type = tensorflow_computation_factory.create_constant( 1, concrete_int_type ) - unused_compiled = building_blocks.CompiledComputation( + unused_compiled = federated_language.framework.CompiledComputation( unused_proto, type_signature=unused_type ) - unused_int = building_blocks.Call(unused_compiled, None) - block_to_id = building_blocks.Block([('x', unused_int)], tf_identity) + unused_int = federated_language.framework.Call(unused_compiled, None) + block_to_id = federated_language.framework.Block( + [('x', unused_int)], tf_identity + ) self.assert_compiles_to_tensorflow(block_to_id) def test_returns_tf_computation_compiled_comp(self): - concrete_int_type = computation_types.TensorType(np.int32) + concrete_int_type = federated_language.TensorType(np.int32) proto, type_signature = tensorflow_computation_factory.create_identity( concrete_int_type ) - tf_identity = building_blocks.CompiledComputation( + tf_identity = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) self.assert_compiles_to_tensorflow(tf_identity) def test_returns_called_tf_computation_with_truct(self): - constant_tuple_type = computation_types.StructType([np.int32, np.float32]) + constant_tuple_type = federated_language.StructType([np.int32, np.float32]) constant_proto, constant_type = ( tensorflow_computation_factory.create_constant(1, constant_tuple_type) ) - constant_compiled = building_blocks.CompiledComputation( + constant_compiled = federated_language.framework.CompiledComputation( constant_proto, type_signature=constant_type ) - constant_tuple = building_blocks.Call(constant_compiled, None) - sel = building_blocks.Selection(source=constant_tuple, index=0) - tup = building_blocks.Struct([sel, sel, sel]) + constant_tuple = federated_language.framework.Call(constant_compiled, None) + sel = federated_language.framework.Selection(source=constant_tuple, index=0) + tup = federated_language.framework.Struct([sel, sel, sel]) self.assert_compiles_to_tensorflow(tup) def test_passes_on_tf(self): proto, type_signature = tensorflow_computation_factory.create_identity( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ) - tf_comp = building_blocks.CompiledComputation( + tf_comp = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) transformed = compiler.compile_local_computation_to_tensorflow(tf_comp) self.assertEqual(tf_comp, transformed) def test_raises_on_xla(self): - function_type = computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), + function_type = federated_language.FunctionType( + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), ) empty_xla_computation_proto = computation_pb2.Computation( - type=type_serialization.serialize_type(function_type), + type=federated_language.framework.serialize_type(function_type), xla=computation_pb2.Xla(), ) - compiled_comp = building_blocks.CompiledComputation( + compiled_comp = federated_language.framework.CompiledComputation( proto=empty_xla_computation_proto ) @@ -518,50 +575,58 @@ def test_raises_on_xla(self): compiler.compile_local_computation_to_tensorflow(compiled_comp) def test_generates_tf_with_lambda(self): - ref_to_x = building_blocks.Reference( - 'x', computation_types.StructType([np.int32, np.float32]) + ref_to_x = federated_language.framework.Reference( + 'x', federated_language.StructType([np.int32, np.float32]) ) - identity_lambda = building_blocks.Lambda( + identity_lambda = federated_language.framework.Lambda( ref_to_x.name, ref_to_x.type_signature, ref_to_x ) self.assert_compiles_to_tensorflow(identity_lambda) def test_generates_tf_with_block(self): - ref_to_x = building_blocks.Reference( - 'x', computation_types.StructType([np.int32, np.float32]) + ref_to_x = federated_language.framework.Reference( + 'x', federated_language.StructType([np.int32, np.float32]) ) - identity_lambda = building_blocks.Lambda( + identity_lambda = federated_language.framework.Lambda( ref_to_x.name, ref_to_x.type_signature, ref_to_x ) zero_proto, zero_type = tensorflow_computation_factory.create_constant( - 0, computation_types.StructType([np.int32, np.float32]) + 0, federated_language.StructType([np.int32, np.float32]) ) - zero_compiled = building_blocks.CompiledComputation( + zero_compiled = federated_language.framework.CompiledComputation( zero_proto, type_signature=zero_type ) - zero = building_blocks.Call(zero_compiled, None) - ref_to_z = building_blocks.Reference('z', [np.int32, np.float32]) - called_lambda_on_z = building_blocks.Call(identity_lambda, ref_to_z) - blk = building_blocks.Block([('z', zero)], called_lambda_on_z) + zero = federated_language.framework.Call(zero_compiled, None) + ref_to_z = federated_language.framework.Reference( + 'z', [np.int32, np.float32] + ) + called_lambda_on_z = federated_language.framework.Call( + identity_lambda, ref_to_z + ) + blk = federated_language.framework.Block([('z', zero)], called_lambda_on_z) self.assert_compiles_to_tensorflow(blk) def test_generates_tf_with_sequence_type(self): - ref_to_x = building_blocks.Reference( - 'x', computation_types.SequenceType([np.int32, np.float32]) + ref_to_x = federated_language.framework.Reference( + 'x', federated_language.SequenceType([np.int32, np.float32]) ) - identity_lambda = building_blocks.Lambda( + identity_lambda = federated_language.framework.Lambda( ref_to_x.name, ref_to_x.type_signature, ref_to_x ) self.assert_compiles_to_tensorflow(identity_lambda) def test_returns_result_with_literal(self): - comp = building_blocks.Literal(1, computation_types.TensorType(np.int32)) + comp = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) result = compiler.compile_local_computation_to_tensorflow(comp) - self.assertIsInstance(result, building_blocks.Call) - self.assertIsInstance(result.function, building_blocks.CompiledComputation) - type_test_utils.assert_types_equivalent( + self.assertIsInstance(result, federated_language.framework.Call) + self.assertIsInstance( + result.function, federated_language.framework.CompiledComputation + ) + federated_language.framework.assert_types_equivalent( comp.type_signature, result.type_signature ) @@ -569,10 +634,11 @@ def test_returns_result_with_literal(self): class CompileLocalSubcomputationsToTensorFlowTest(absltest.TestCase): def test_leaves_federated_comp_alone(self): - ref_to_federated_x = building_blocks.Reference( - 'x', computation_types.FederatedType(np.int32, placements.SERVER) + ref_to_federated_x = federated_language.framework.Reference( + 'x', + federated_language.FederatedType(np.int32, federated_language.SERVER), ) - identity_lambda = building_blocks.Lambda( + identity_lambda = federated_language.framework.Lambda( ref_to_federated_x.name, ref_to_federated_x.type_signature, ref_to_federated_x, @@ -583,32 +649,35 @@ def test_leaves_federated_comp_alone(self): self.assertEqual(transformed, identity_lambda) def test_compiles_lambda_under_federated_comp_to_tf(self): - ref_to_x = building_blocks.Reference( - 'x', computation_types.StructType([np.int32, np.float32]) + ref_to_x = federated_language.framework.Reference( + 'x', federated_language.StructType([np.int32, np.float32]) ) - identity_lambda = building_blocks.Lambda( + identity_lambda = federated_language.framework.Lambda( ref_to_x.name, ref_to_x.type_signature, ref_to_x ) any_proto = building_block_test_utils.create_any_proto_from_array( np.array(1) ) - federated_data = building_blocks.Data( + federated_data = federated_language.framework.Data( any_proto, - computation_types.FederatedType( - computation_types.StructType([np.int32, np.float32]), - placements.SERVER, + federated_language.FederatedType( + federated_language.StructType([np.int32, np.float32]), + federated_language.SERVER, ), ) - applied = building_block_factory.create_federated_apply( + applied = federated_language.framework.create_federated_apply( identity_lambda, federated_data ) transformed = compiler.compile_local_subcomputations_to_tensorflow(applied) - self.assertIsInstance(transformed, building_blocks.Call) - self.assertIsInstance(transformed.function, building_blocks.Intrinsic) + self.assertIsInstance(transformed, federated_language.framework.Call) + self.assertIsInstance( + transformed.function, federated_language.framework.Intrinsic + ) self.assertIsInstance( - transformed.argument[0], building_blocks.CompiledComputation + transformed.argument[0], + federated_language.framework.CompiledComputation, ) self.assertEqual(transformed.argument[1], federated_data) self.assertEqual( @@ -616,9 +685,13 @@ def test_compiles_lambda_under_federated_comp_to_tf(self): ) def test_leaves_local_comp_with_unbound_reference_alone(self): - ref_to_x = building_blocks.Reference('x', [np.int32, np.float32]) - ref_to_z = building_blocks.Reference('z', [np.int32, np.float32]) - lambda_with_unbound_ref = building_blocks.Lambda( + ref_to_x = federated_language.framework.Reference( + 'x', [np.int32, np.float32] + ) + ref_to_z = federated_language.framework.Reference( + 'z', [np.int32, np.float32] + ) + lambda_with_unbound_ref = federated_language.framework.Lambda( ref_to_x.name, ref_to_x.type_signature, ref_to_z ) transformed = compiler.compile_local_subcomputations_to_tensorflow( @@ -631,61 +704,65 @@ def test_leaves_local_comp_with_unbound_reference_alone(self): class ConcatenateFunctionOutputsTest(absltest.TestCase): def test_raises_on_non_lambda_args(self): - reference = building_blocks.Reference('x', np.int32) - tff_lambda = building_blocks.Lambda('x', np.int32, reference) + reference = federated_language.framework.Reference('x', np.int32) + tff_lambda = federated_language.framework.Lambda('x', np.int32, reference) with self.assertRaises(TypeError): compiler.concatenate_function_outputs(tff_lambda, reference) with self.assertRaises(TypeError): compiler.concatenate_function_outputs(reference, tff_lambda) def test_raises_on_non_unique_names(self): - reference = building_blocks.Reference('x', np.int32) - good_lambda = building_blocks.Lambda('x', np.int32, reference) - bad_lambda = building_blocks.Lambda('x', np.int32, good_lambda) + reference = federated_language.framework.Reference('x', np.int32) + good_lambda = federated_language.framework.Lambda('x', np.int32, reference) + bad_lambda = federated_language.framework.Lambda('x', np.int32, good_lambda) with self.assertRaises(ValueError): compiler.concatenate_function_outputs(good_lambda, bad_lambda) with self.assertRaises(ValueError): compiler.concatenate_function_outputs(bad_lambda, good_lambda) def test_raises_on_different_parameter_types(self): - int_reference = building_blocks.Reference('x', np.int32) - int_lambda = building_blocks.Lambda('x', np.int32, int_reference) - float_reference = building_blocks.Reference('x', np.float32) - float_lambda = building_blocks.Lambda('x', np.float32, float_reference) + int_reference = federated_language.framework.Reference('x', np.int32) + int_lambda = federated_language.framework.Lambda( + 'x', np.int32, int_reference + ) + float_reference = federated_language.framework.Reference('x', np.float32) + float_lambda = federated_language.framework.Lambda( + 'x', np.float32, float_reference + ) with self.assertRaises(TypeError): compiler.concatenate_function_outputs(int_lambda, float_lambda) def test_parameters_are_mapped_together(self): - x_reference = building_blocks.Reference('x', np.int32) - x_lambda = building_blocks.Lambda('x', np.int32, x_reference) - y_reference = building_blocks.Reference('y', np.int32) - y_lambda = building_blocks.Lambda('y', np.int32, y_reference) + x_reference = federated_language.framework.Reference('x', np.int32) + x_lambda = federated_language.framework.Lambda('x', np.int32, x_reference) + y_reference = federated_language.framework.Reference('y', np.int32) + y_lambda = federated_language.framework.Lambda('y', np.int32, y_reference) concatenated = compiler.concatenate_function_outputs(x_lambda, y_lambda) parameter_name = concatenated.parameter_name def _raise_on_other_name_reference(comp): if ( - isinstance(comp, building_blocks.Reference) + isinstance(comp, federated_language.framework.Reference) and comp.name != parameter_name ): raise ValueError return comp, True - tree_analysis.check_has_unique_names(concatenated) - transformation_utils.transform_postorder( + federated_language.framework.check_has_unique_names(concatenated) + federated_language.framework.transform_postorder( concatenated, _raise_on_other_name_reference ) def test_concatenates_identities(self): - x_reference = building_blocks.Reference('x', np.int32) - x_lambda = building_blocks.Lambda('x', np.int32, x_reference) - y_reference = building_blocks.Reference('y', np.int32) - y_lambda = building_blocks.Lambda('y', np.int32, y_reference) + x_reference = federated_language.framework.Reference('x', np.int32) + x_lambda = federated_language.framework.Lambda('x', np.int32, x_reference) + y_reference = federated_language.framework.Reference('y', np.int32) + y_lambda = federated_language.framework.Lambda('y', np.int32, y_reference) concatenated = compiler.concatenate_function_outputs(x_lambda, y_lambda) self.assertEqual(str(concatenated), '(y -> )') if __name__ == '__main__': context = _create_test_context() - set_default_context.set_default_context(context) + federated_language.framework.set_default_context(context) absltest.main() diff --git a/tensorflow_federated/python/core/backends/mapreduce/distribute_aggregate_test_utils.py b/tensorflow_federated/python/core/backends/mapreduce/distribute_aggregate_test_utils.py index e2ed1cb8b0..8147bf3832 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/distribute_aggregate_test_utils.py +++ b/tensorflow_federated/python/core/backends/mapreduce/distribute_aggregate_test_utils.py @@ -15,20 +15,13 @@ import collections +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.mapreduce import forms from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements DistributeAggregateFormExample = collections.namedtuple( 'DistributeAggregateFormExample', ['daf', 'initialize'] @@ -36,38 +29,40 @@ def generate_unnamed_type_signature( - server_prepare: computation_impl.ConcreteComputation, - client_work: computation_impl.ConcreteComputation, - server_result: computation_impl.ConcreteComputation, -) -> computation_types.FunctionType: + server_prepare: federated_language.framework.ConcreteComputation, + client_work: federated_language.framework.ConcreteComputation, + server_result: federated_language.framework.ConcreteComputation, +) -> federated_language.FunctionType: """Generates a type signature for the DistributeAggregateForm.""" - parameter = computation_types.StructType([ + parameter = federated_language.StructType([ server_prepare.type_signature.parameter, client_work.type_signature.parameter[0], # pytype: disable=unsupported-operands ]) result = server_result.type_signature.result - return computation_types.FunctionType(parameter, result) + return federated_language.FunctionType(parameter, result) def _make_distribute_aggregate_form_example( - initialize: computation_impl.ConcreteComputation, - type_signature: computation_types.FunctionType, - server_prepare: computation_impl.ConcreteComputation, - server_to_client_broadcast: computation_impl.ConcreteComputation, - client_work: computation_impl.ConcreteComputation, - client_to_server_aggregation: computation_impl.ConcreteComputation, - server_result: computation_impl.ConcreteComputation, + initialize: federated_language.framework.ConcreteComputation, + type_signature: federated_language.FunctionType, + server_prepare: federated_language.framework.ConcreteComputation, + server_to_client_broadcast: federated_language.framework.ConcreteComputation, + client_work: federated_language.framework.ConcreteComputation, + client_to_server_aggregation: federated_language.framework.ConcreteComputation, + server_result: federated_language.framework.ConcreteComputation, ) -> DistributeAggregateFormExample: """Constructs a DistributeAggregateFormExample given the component comps.""" - def _uniquify_reference_names(comp: computation_impl.ConcreteComputation): + def _uniquify_reference_names( + comp: federated_language.framework.ConcreteComputation, + ): building_block = comp.to_building_block() transformed_comp = tree_transformations.uniquify_reference_names( building_block )[0] - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=transformed_comp.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) return DistributeAggregateFormExample( @@ -91,25 +86,30 @@ def get_temperature_sensor_example() -> DistributeAggregateFormExample: Returns: A tuple of: (1) an instance of `forms.DistributeAggregateForm` and (2) an - associated `computation_base.Computation` that generates an initial state + associated `federated_language.framework.Computation` that generates an + initial state compatible with the server state expected by the `forms.DistributeAggregateForm`. """ - @federated_computation.federated_computation() + @federated_language.federated_computation() def initialize(): @tensorflow_computation.tf_computation def initialize_tf(): return collections.OrderedDict(num_rounds=0) - return intrinsics.federated_value(initialize_tf(), placements.SERVER) + return federated_language.federated_value( + initialize_tf(), federated_language.SERVER + ) # The state of the server is a struct containing just the integer # counter `num_rounds`. server_state_type = [('num_rounds', np.int32)] - @federated_computation.federated_computation( - computation_types.FederatedType(server_state_type, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType( + server_state_type, federated_language.SERVER + ) ) def server_prepare(state): @tensorflow_computation.tf_computation(server_state_type) @@ -118,7 +118,9 @@ def server_prepare_tf(state): max_temperature=32.0 + tf.cast(state['num_rounds'], tf.float32) ) - broadcast_args = [intrinsics.federated_map(server_prepare_tf, state)] + broadcast_args = [ + federated_language.federated_map(server_prepare_tf, state) + ] intermediate_state = [state] return broadcast_args, intermediate_state @@ -128,21 +130,31 @@ def server_prepare_tf(state): # The intermediate state will contain the server state. intermediate_state_type = [ - computation_types.FederatedType(server_state_type, placements.SERVER) + federated_language.FederatedType( + server_state_type, federated_language.SERVER + ) ] - @federated_computation.federated_computation( - [computation_types.FederatedType(broadcast_type, placements.SERVER)] - ) + @federated_language.federated_computation([ + federated_language.FederatedType( + broadcast_type, federated_language.SERVER + ) + ]) def server_to_client_broadcast(context_at_server): - return [intrinsics.federated_broadcast(context_at_server[0])] + return [federated_language.federated_broadcast(context_at_server[0])] # The client data is a sequence of floats. - client_data_type = computation_types.SequenceType(np.float32) + client_data_type = federated_language.SequenceType(np.float32) - @federated_computation.federated_computation( - computation_types.FederatedType(client_data_type, placements.CLIENTS), - [computation_types.FederatedType(broadcast_type, placements.CLIENTS)], + @federated_language.federated_computation( + federated_language.FederatedType( + client_data_type, federated_language.CLIENTS + ), + [ + federated_language.FederatedType( + broadcast_type, federated_language.CLIENTS + ) + ], ) def client_work(data, context_at_client): @tensorflow_computation.tf_computation(client_data_type, [broadcast_type]) @@ -165,14 +177,14 @@ def fn(s, x): ) return client_updates - results = intrinsics.federated_map( + results = federated_language.federated_map( client_work_tf, (data, context_at_client) ) - unzipped_results = building_block_factory.create_federated_unzip( + unzipped_results = federated_language.framework.create_federated_unzip( results.comp ) - return value_impl.Value( - context_stack_impl.context_stack.current.bind_computation_to_reference( + return federated_language.Value( + federated_language.framework.global_context_stack.current.bind_computation_to_reference( unzipped_results ) ) @@ -181,29 +193,35 @@ def fn(s, x): federated_client_update_type = [ ( 'is_over', - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ), ( 'weight', - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ), ] - @federated_computation.federated_computation( + @federated_language.federated_computation( intermediate_state_type, federated_client_update_type ) def client_to_server_aggregation(intermediate_server_state, client_updates): del intermediate_server_state # Unused - return [intrinsics.federated_mean(client_updates[0], client_updates[1])] + return [ + federated_language.federated_mean(client_updates[0], client_updates[1]) + ] # The aggregation result type is a single float. aggregation_result_type = np.float32 - @federated_computation.federated_computation( + @federated_language.federated_computation( intermediate_state_type, [ - computation_types.FederatedType( - aggregation_result_type, placements.SERVER + federated_language.FederatedType( + aggregation_result_type, federated_language.SERVER ) ], ) @@ -212,7 +230,7 @@ def server_result(intermediate_server_state, aggregation_result): def server_result_tf(server_state): return collections.OrderedDict(num_rounds=server_state['num_rounds'] + 1) - return intrinsics.federated_map( + return federated_language.federated_map( server_result_tf, intermediate_server_state[0] ), collections.OrderedDict(ratio_over_threshold=aggregation_result[0]) @@ -240,7 +258,7 @@ def get_mnist_training_example() -> DistributeAggregateFormExample: server_state_nt = collections.namedtuple('ServerState', 'model num_rounds') # Start with a model filled with zeros, and the round counter set to zero. - @federated_computation.federated_computation() + @federated_language.federated_computation() def initialize(): @tensorflow_computation.tf_computation def initialize_tf(): @@ -249,7 +267,9 @@ def initialize_tf(): num_rounds=0, ) - return intrinsics.federated_value(initialize_tf(), placements.SERVER) + return federated_language.federated_value( + initialize_tf(), federated_language.SERVER + ) server_state_tff_type = server_state_nt( model=model_nt(weights=(np.float32, [784, 10]), bias=(np.float32, [10])), @@ -260,8 +280,10 @@ def initialize_tf(): # Prepare the broadcast input containing the model and a dynamically adjusted # learning rate that starts at 0.1 and decays exponentially by a factor of # 0.9. - @federated_computation.federated_computation( - computation_types.FederatedType(server_state_tff_type, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType( + server_state_tff_type, federated_language.SERVER + ) ) def server_prepare(state): @tensorflow_computation.tf_computation(server_state_tff_type) @@ -269,7 +291,9 @@ def server_prepare_tf(state): learning_rate = 0.1 * tf.pow(0.9, tf.cast(state.num_rounds, tf.float32)) return client_state_nt(model=state.model, learning_rate=learning_rate) - broadcast_args = [intrinsics.federated_map(server_prepare_tf, state)] + broadcast_args = [ + federated_language.federated_map(server_prepare_tf, state) + ] intermediate_state = [32, state] return broadcast_args, intermediate_state @@ -277,7 +301,9 @@ def server_prepare_tf(state): # sum and the server state. intermediate_state_type = [ np.int32, - computation_types.FederatedType(server_state_tff_type, placements.SERVER), + federated_language.FederatedType( + server_state_tff_type, federated_language.SERVER + ), ] model_tff_type = model_nt( @@ -287,24 +313,32 @@ def server_prepare_tf(state): model=model_tff_type, learning_rate=np.float32 ) - @federated_computation.federated_computation( - [computation_types.FederatedType(broadcast_tff_type, placements.SERVER)] - ) + @federated_language.federated_computation([ + federated_language.FederatedType( + broadcast_tff_type, federated_language.SERVER + ) + ]) def server_to_client_broadcast(context_at_server): - return [intrinsics.federated_broadcast(context_at_server[0])] + return [federated_language.federated_broadcast(context_at_server[0])] batch_nt = collections.namedtuple('Batch', 'x y') batch_tff_type = batch_nt(x=(np.float32, [None, 784]), y=(np.int32, [None])) - dataset_tff_type = computation_types.SequenceType(batch_tff_type) + dataset_tff_type = federated_language.SequenceType(batch_tff_type) loop_state_nt = collections.namedtuple('LoopState', 'num_examples total_loss') update_nt = collections.namedtuple('Update', 'model num_examples loss') # Train the model locally, emit the locally-trained model and the number of # examples as an update, and the average loss and the number of examples as # local client stats. - @federated_computation.federated_computation( - computation_types.FederatedType(dataset_tff_type, placements.CLIENTS), - [computation_types.FederatedType(broadcast_tff_type, placements.CLIENTS)], + @federated_language.federated_computation( + federated_language.FederatedType( + dataset_tff_type, federated_language.CLIENTS + ), + [ + federated_language.FederatedType( + broadcast_tff_type, federated_language.CLIENTS + ) + ], ) def client_work(data, context_at_client): @tensorflow_computation.tf_computation(dataset_tff_type, broadcast_tff_type) @@ -357,14 +391,15 @@ def reduce_fn(loop_state, batch): def cast_to_float(val): return tf.cast(val, tf.float32) - results = intrinsics.federated_map( - client_work_tf, intrinsics.federated_zip([data, context_at_client[0]]) + results = federated_language.federated_map( + client_work_tf, + federated_language.federated_zip([data, context_at_client[0]]), ) - unzipped_results = building_block_factory.create_federated_unzip( + unzipped_results = federated_language.framework.create_federated_unzip( results.comp ) - client_update = value_impl.Value( - context_stack_impl.context_stack.current.bind_computation_to_reference( + client_update = federated_language.Value( + federated_language.framework.global_context_stack.current.bind_computation_to_reference( unzipped_results ) ) @@ -372,51 +407,63 @@ def cast_to_float(val): return [ # input for federated_mean client_update.model, - intrinsics.federated_map(cast_to_float, client_update.num_examples), + federated_language.federated_map( + cast_to_float, client_update.num_examples + ), # input for federated_sum client_update.num_examples, # input for federated_mean client_update.loss, - intrinsics.federated_map(cast_to_float, client_update.num_examples), + federated_language.federated_map( + cast_to_float, client_update.num_examples + ), ] federated_aggregation_input_tff_type = [ # input for federated_mean - computation_types.FederatedType(model_tff_type, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType( + model_tff_type, federated_language.CLIENTS + ), + federated_language.FederatedType(np.float32, federated_language.CLIENTS), # input for federated_sum - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), # input for federated_mean - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType(np.float32, federated_language.CLIENTS), + federated_language.FederatedType(np.float32, federated_language.CLIENTS), ] - @federated_computation.federated_computation( + @federated_language.federated_computation( intermediate_state_type, federated_aggregation_input_tff_type ) def client_to_server_aggregation(intermediate_server_state, client_updates): - scaled_model = intrinsics.federated_mean( + scaled_model = federated_language.federated_mean( client_updates[0], client_updates[1] ) - num_examples = intrinsics.federated_secure_sum_bitwidth( + num_examples = federated_language.federated_secure_sum_bitwidth( client_updates[2], intermediate_server_state[0] ) - scaled_loss = intrinsics.federated_mean( + scaled_loss = federated_language.federated_mean( client_updates[3], client_updates[4] ) return [scaled_model, num_examples, scaled_loss] # The aggregation result type is a struct. federated_aggregation_result_type = update_nt( - model=computation_types.FederatedType(model_tff_type, placements.SERVER), - num_examples=computation_types.FederatedType(np.int32, placements.SERVER), - loss=computation_types.FederatedType(np.float32, placements.SERVER), + model=federated_language.FederatedType( + model_tff_type, federated_language.SERVER + ), + num_examples=federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + loss=federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ) metrics_nt = collections.namedtuple('Metrics', 'num_rounds num_examples loss') - @federated_computation.federated_computation( + @federated_language.federated_computation( intermediate_state_type, federated_aggregation_result_type, ) @@ -425,14 +472,14 @@ def server_result(intermediate_server_state, aggregation_result): def server_result_tf(state): return state.num_rounds + 1 - num_rounds = intrinsics.federated_map( + num_rounds = federated_language.federated_map( server_result_tf, intermediate_server_state[1] ) - new_server_state = intrinsics.federated_zip( + new_server_state = federated_language.federated_zip( server_state_nt(model=aggregation_result.model, num_rounds=num_rounds) ) - metrics = intrinsics.federated_zip( + metrics = federated_language.federated_zip( metrics_nt( num_rounds=num_rounds, num_examples=aggregation_result.num_examples, diff --git a/tensorflow_federated/python/core/backends/mapreduce/form_utils.py b/tensorflow_federated/python/core/backends/mapreduce/form_utils.py index de9f8bcde7..d32568331d 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/form_utils.py +++ b/tensorflow_federated/python/core/backends/mapreduce/form_utils.py @@ -20,6 +20,7 @@ from collections.abc import Callable from typing import Optional +import federated_language import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck @@ -29,20 +30,8 @@ from tensorflow_federated.python.core.backends.mapreduce import intrinsics as mapreduce_intrinsics from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils from tensorflow_federated.python.core.impl.compiler import transformations -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements _GRAPPLER_DEFAULT_CONFIG = tf.compat.v1.ConfigProto() _AGGRESSIVE = _GRAPPLER_DEFAULT_CONFIG.graph_options.rewrite_options.AGGRESSIVE @@ -63,54 +52,62 @@ ) BuildingBlockFn = Callable[ - [building_blocks.ComputationBuildingBlock], - building_blocks.ComputationBuildingBlock, + [federated_language.framework.ComputationBuildingBlock], + federated_language.framework.ComputationBuildingBlock, ] def get_computation_for_broadcast_form( bf: forms.BroadcastForm, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates `tff.Computation` from a broadcast form.""" py_typecheck.check_type(bf, forms.BroadcastForm) server_data_type = bf.compute_server_context.type_signature.parameter client_data_type = bf.client_processing.type_signature.parameter[1] - comp_parameter_type = computation_types.StructType([ + comp_parameter_type = federated_language.StructType([ ( bf.server_data_label, - computation_types.FederatedType(server_data_type, placements.SERVER), + federated_language.FederatedType( + server_data_type, federated_language.SERVER + ), ), ( bf.client_data_label, - computation_types.FederatedType(client_data_type, placements.CLIENTS), + federated_language.FederatedType( + client_data_type, federated_language.CLIENTS + ), ), ]) - @federated_computation.federated_computation(comp_parameter_type) + @federated_language.federated_computation(comp_parameter_type) def computation(arg): server_data, client_data = arg - context_at_server = intrinsics.federated_map( + context_at_server = federated_language.federated_map( bf.compute_server_context, server_data ) - context_at_clients = intrinsics.federated_broadcast(context_at_server) - client_processing_arg = intrinsics.federated_zip( + context_at_clients = federated_language.federated_broadcast( + context_at_server + ) + client_processing_arg = federated_language.federated_zip( (context_at_clients, client_data) ) - return intrinsics.federated_map(bf.client_processing, client_processing_arg) + return federated_language.federated_map( + bf.client_processing, client_processing_arg + ) return computation def get_state_initialization_computation( - initialize_computation: computation_impl.ConcreteComputation, + initialize_computation: federated_language.framework.ConcreteComputation, grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Validates and transforms a computation to generate state. Args: - initialize_computation: A `computation_impl.ConcreteComputation` that should - generate initial state for a computation that is compatible with a - federated learning system that implements the contract of a backend + initialize_computation: A `federated_language.framework.ConcreteComputation` + that should generate initial state for a computation that is compatible + with a federated learning system that implements the contract of a backend defined in the backends/mapreduce directory. grappler_config: An optional instance of `tf.compat.v1.ConfigProto` to configure Grappler graph optimization of the TensorFlow graphs. These @@ -120,7 +117,8 @@ def get_state_initialization_computation( bypassed. Returns: - A `computation_base.Computation` that can generate state for a computation. + A `federated_language.framework.Computation` that can generate state for a + computation. Raises: TypeError: If the arguments are of the wrong types. @@ -128,8 +126,8 @@ def get_state_initialization_computation( init_type = initialize_computation.type_signature _check_type_is_no_arg_fn(init_type, '`initialize`', TypeError) if ( - not isinstance(init_type.result, computation_types.FederatedType) - or init_type.result.placement is not placements.SERVER + not isinstance(init_type.result, federated_language.FederatedType) + or init_type.result.placement is not federated_language.SERVER ): raise TypeError( 'Expected `initialize` to return a single federated value ' @@ -142,19 +140,21 @@ def get_state_initialization_computation( initialize_tree ) ) - tree_analysis.check_contains_only_reducible_intrinsics(initialize_tree) + federated_language.framework.check_contains_only_reducible_intrinsics( + initialize_tree + ) initialize_tree = compiler.consolidate_and_extract_local_processing( initialize_tree, grappler_config ) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=initialize_tree.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) def get_computation_for_map_reduce_form( mrf: forms.MapReduceForm, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates `tff.Computation` from a MapReduce form. Args: @@ -168,26 +168,30 @@ def get_computation_for_map_reduce_form( """ py_typecheck.check_type(mrf, forms.MapReduceForm) - @federated_computation.federated_computation(mrf.type_signature.parameter) + @federated_language.federated_computation(mrf.type_signature.parameter) def computation(arg): """The logic of a single MapReduce processing round.""" server_state, client_data = arg - broadcast_input = intrinsics.federated_map(mrf.prepare, server_state) - broadcast_result = intrinsics.federated_broadcast(broadcast_input) - work_arg = intrinsics.federated_zip([client_data, broadcast_result]) + broadcast_input = federated_language.federated_map( + mrf.prepare, server_state + ) + broadcast_result = federated_language.federated_broadcast(broadcast_input) + work_arg = federated_language.federated_zip([client_data, broadcast_result]) ( aggregate_input, secure_sum_bitwidth_input, secure_sum_input, secure_modular_sum_input, - ) = intrinsics.federated_map(mrf.work, work_arg) - aggregate_result = intrinsics.federated_aggregate( + ) = federated_language.federated_map(mrf.work, work_arg) + aggregate_result = federated_language.federated_aggregate( aggregate_input, mrf.zero(), mrf.accumulate, mrf.merge, mrf.report ) - secure_sum_bitwidth_result = intrinsics.federated_secure_sum_bitwidth( - secure_sum_bitwidth_input, mrf.secure_sum_bitwidth() + secure_sum_bitwidth_result = ( + federated_language.federated_secure_sum_bitwidth( + secure_sum_bitwidth_input, mrf.secure_sum_bitwidth() + ) ) - secure_sum_result = intrinsics.federated_secure_sum( + secure_sum_result = federated_language.federated_secure_sum( secure_sum_input, mrf.secure_sum_max_input() ) secure_modular_sum_result = ( @@ -195,7 +199,7 @@ def computation(arg): secure_modular_sum_input, mrf.secure_modular_sum_modulus() ) ) - update_arg = intrinsics.federated_zip(( + update_arg = federated_language.federated_zip(( server_state, ( aggregate_result, @@ -204,7 +208,7 @@ def computation(arg): secure_modular_sum_result, ), )) - updated_server_state, server_output = intrinsics.federated_map( + updated_server_state, server_output = federated_language.federated_map( mrf.update, update_arg ) return updated_server_state, server_output @@ -214,7 +218,7 @@ def computation(arg): def get_computation_for_distribute_aggregate_form( daf: forms.DistributeAggregateForm, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates `tff.Computation` from a DistributeAggregate form. Args: @@ -228,7 +232,7 @@ def get_computation_for_distribute_aggregate_form( """ py_typecheck.check_type(daf, forms.DistributeAggregateForm) - @federated_computation.federated_computation(daf.type_signature.parameter) + @federated_language.federated_computation(daf.type_signature.parameter) def computation(arg): """The logic of a single federated computation round.""" server_state, client_data = arg @@ -247,18 +251,18 @@ def computation(arg): def _check_type_is_fn( - target: computation_types.Type, + target: federated_language.Type, name: str, err_fn: Callable[[str], Exception] = compiler.MapReduceFormCompilationError, ): - if not isinstance(target, computation_types.FunctionType): + if not isinstance(target, federated_language.FunctionType): raise err_fn( f'Expected {name} to be a function, but {name} had type {target}.' ) def _check_type_is_no_arg_fn( - target: computation_types.Type, + target: federated_language.Type, name: str, err_fn: Callable[[str], Exception] = compiler.MapReduceFormCompilationError, ): @@ -271,12 +275,12 @@ def _check_type_is_no_arg_fn( def _check_function_signature_compatible_with_broadcast_form( - function_type: computation_types.FunctionType, + function_type: federated_language.FunctionType, ): """Tests compatibility with `tff.backends.mapreduce.BroadcastForm`.""" - py_typecheck.check_type(function_type, computation_types.FunctionType) + py_typecheck.check_type(function_type, federated_language.FunctionType) if not ( - isinstance(function_type.parameter, computation_types.StructType) + isinstance(function_type.parameter, federated_language.StructType) and len(function_type.parameter) == 2 ): raise TypeError( @@ -286,8 +290,8 @@ def _check_function_signature_compatible_with_broadcast_form( ) server_data_type, client_data_type = function_type.parameter # pytype: disable=attribute-error if ( - not isinstance(server_data_type, computation_types.FederatedType) - or server_data_type.placement is not placements.SERVER + not isinstance(server_data_type, federated_language.FederatedType) + or server_data_type.placement is not federated_language.SERVER ): raise TypeError( '`BroadcastForm` expects a computation whose first parameter is server ' @@ -295,8 +299,8 @@ def _check_function_signature_compatible_with_broadcast_form( f'type:\n{server_data_type}' ) if ( - not isinstance(client_data_type, computation_types.FederatedType) - or client_data_type.placement is not placements.CLIENTS + not isinstance(client_data_type, federated_language.FederatedType) + or client_data_type.placement is not federated_language.CLIENTS ): raise TypeError( '`BroadcastForm` expects a computation whose first parameter is client ' @@ -305,8 +309,8 @@ def _check_function_signature_compatible_with_broadcast_form( ) result_type = function_type.result if ( - not isinstance(result_type, computation_types.FederatedType) - or result_type.placement is not placements.CLIENTS + not isinstance(result_type, federated_language.FederatedType) + or result_type.placement is not federated_language.CLIENTS ): raise TypeError( '`BroadcastForm` expects a computation whose result is client data ' @@ -316,38 +320,38 @@ def _check_function_signature_compatible_with_broadcast_form( def _check_contains_only_reducible_intrinsics( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, ): """Checks that `comp` contains intrinsics reducible to aggregate or broadcast. Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` to check for - presence of intrinsics not currently immediately reducible to + comp: Instance of `federated_language.framework.ComputationBuildingBlock` to + check for presence of intrinsics not currently immediately reducible to `FEDERATED_AGGREGATE` or `FEDERATED_BROADCAST`, or local processing. Raises: ValueError: If we encounter an intrinsic under `comp` that is not reducible. """ reducible_uris = ( - intrinsic_defs.FEDERATED_AGGREGATE.uri, - intrinsic_defs.FEDERATED_APPLY.uri, - intrinsic_defs.FEDERATED_BROADCAST.uri, - intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri, - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, - intrinsic_defs.FEDERATED_MAP.uri, - intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri, - intrinsic_defs.FEDERATED_SECURE_SUM.uri, - intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri, - intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri, + federated_language.framework.FEDERATED_AGGREGATE.uri, + federated_language.framework.FEDERATED_APPLY.uri, + federated_language.framework.FEDERATED_BROADCAST.uri, + federated_language.framework.FEDERATED_EVAL_AT_CLIENTS.uri, + federated_language.framework.FEDERATED_EVAL_AT_SERVER.uri, + federated_language.framework.FEDERATED_MAP_ALL_EQUAL.uri, + federated_language.framework.FEDERATED_MAP.uri, + federated_language.framework.FEDERATED_SECURE_SUM_BITWIDTH.uri, + federated_language.framework.FEDERATED_SECURE_SUM.uri, + federated_language.framework.FEDERATED_VALUE_AT_CLIENTS.uri, + federated_language.framework.FEDERATED_VALUE_AT_SERVER.uri, + federated_language.framework.FEDERATED_ZIP_AT_CLIENTS.uri, + federated_language.framework.FEDERATED_ZIP_AT_SERVER.uri, mapreduce_intrinsics.FEDERATED_SECURE_MODULAR_SUM.uri, ) def _check(comp): if ( - isinstance(comp, building_blocks.Intrinsic) + isinstance(comp, federated_language.framework.Intrinsic) and comp.uri not in reducible_uris ): raise ValueError( @@ -355,14 +359,14 @@ def _check(comp): 'broadcast, the intrinsic {}'.format(comp.compact_representation()) ) - tree_analysis.visit_postorder(comp, _check) + federated_language.framework.visit_postorder(comp, _check) def check_computation_compatible_with_map_reduce_form( - comp: computation_impl.ConcreteComputation, + comp: federated_language.framework.ConcreteComputation, *, tff_internal_preprocessing: Optional[BuildingBlockFn] = None, -) -> building_blocks.ComputationBuildingBlock: +) -> federated_language.framework.ComputationBuildingBlock: """Tests compatibility with `tff.backends.mapreduce.MapReduceForm`. Note: the conditions here are specified in the documentation for @@ -370,8 +374,8 @@ def check_computation_compatible_with_map_reduce_form( be propagated to that documentation. Args: - comp: An instance of `computation_impl.ConcreteComputation` to check for - compatibility with `tff.backends.mapreduce.MapReduceForm`. + comp: An instance of `federated_language.framework.ConcreteComputation` to + check for compatibility with `tff.backends.mapreduce.MapReduceForm`. tff_internal_preprocessing: An optional function to transform the AST of the computation. @@ -382,7 +386,9 @@ def check_computation_compatible_with_map_reduce_form( Raises: TypeError: If the arguments are of the wrong types. """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) + py_typecheck.check_type( + comp, federated_language.framework.ConcreteComputation + ) comp_tree = comp.to_building_block() if tff_internal_preprocessing is not None: comp_tree = tff_internal_preprocessing(comp_tree) @@ -390,7 +396,7 @@ def check_computation_compatible_with_map_reduce_form( comp_type = comp_tree.type_signature _check_type_is_fn(comp_type, '`comp`', TypeError) if ( - not isinstance(comp_type.parameter, computation_types.StructType) + not isinstance(comp_type.parameter, federated_language.StructType) or len(comp_type.parameter) != 2 ): # pytype: disable=attribute-error raise TypeError( @@ -398,7 +404,7 @@ def check_computation_compatible_with_map_reduce_form( f' type:\n{comp_type.parameter}' # pytype: disable=attribute-error ) if ( - not isinstance(comp_type.result, computation_types.StructType) + not isinstance(comp_type.result, federated_language.StructType) or len(comp_type.result) != 2 ): # pytype: disable=attribute-error raise TypeError( @@ -412,7 +418,9 @@ def check_computation_compatible_with_map_reduce_form( comp_tree = _replace_lambda_body_with_call_dominant_form(comp_tree) _check_contains_only_reducible_intrinsics(comp_tree) - tree_analysis.check_broadcast_not_dependent_on_aggregate(comp_tree) + federated_language.framework.check_broadcast_not_dependent_on_aggregate( + comp_tree + ) return comp_tree @@ -421,41 +429,45 @@ def _untuple_broadcast_only_before_after(before, after): """Removes the tuple-ing of the `broadcast` params and results.""" # Since there is only a single intrinsic here, there's no need for the outer # `{intrinsic_name}_param`/`{intrinsic_name}_result` tuples. - untupled_before = building_block_factory.select_output_from_lambda( + untupled_before = federated_language.framework.select_output_from_lambda( before, 'federated_broadcast_param' ) - after_param_name = next(building_block_factory.unique_name_generator(after)) - after_param_type = computation_types.StructType([ + after_param_name = next( + federated_language.framework.unique_name_generator(after) + ) + after_param_type = federated_language.StructType([ ('original_arg', after.parameter_type.original_arg), # pytype: disable=attribute-error ( 'federated_broadcast_result', after.parameter_type.intrinsic_results.federated_broadcast_result, # pytype: disable=attribute-error ), ]) - after_param_ref = building_blocks.Reference( + after_param_ref = federated_language.framework.Reference( after_param_name, after_param_type ) - after_result_arg = building_blocks.Struct([ + after_result_arg = federated_language.framework.Struct([ ( 'original_arg', - building_blocks.Selection(after_param_ref, 'original_arg'), + federated_language.framework.Selection( + after_param_ref, 'original_arg' + ), ), ( 'intrinsic_results', - building_blocks.Struct([( + federated_language.framework.Struct([( 'federated_broadcast_result', - building_blocks.Selection( + federated_language.framework.Selection( after_param_ref, 'federated_broadcast_result' ), )]), ), ]) - after_result = building_blocks.Call( + after_result = federated_language.framework.Call( after, after_result_arg, ) - untupled_after = building_blocks.Lambda( + untupled_after = federated_language.framework.Lambda( after_param_name, after_param_type, after_result, @@ -515,52 +527,54 @@ def _prepare_for_rebinding(bb): def _construct_selection_from_federated_tuple( - federated_tuple: building_blocks.ComputationBuildingBlock, + federated_tuple: federated_language.framework.ComputationBuildingBlock, index: int, name_generator, -) -> building_blocks.ComputationBuildingBlock: +) -> federated_language.framework.ComputationBuildingBlock: """Selects the index `selected_index` from `federated_tuple`.""" if not isinstance( - federated_tuple.type_signature, computation_types.FederatedType + federated_tuple.type_signature, federated_language.FederatedType ): raise ValueError( 'Expected a `tff.FederatedType`, found' f' {federated_tuple.type_signature}.' ) member_type = federated_tuple.type_signature.member - if not isinstance(member_type, computation_types.StructType): + if not isinstance(member_type, federated_language.StructType): raise ValueError(f'Expected a `tff.StructType`, found {member_type}.') param_name = next(name_generator) - selecting_function = building_blocks.Lambda( + selecting_function = federated_language.framework.Lambda( param_name, member_type, - building_blocks.Selection( - building_blocks.Reference(param_name, member_type), + federated_language.framework.Selection( + federated_language.framework.Reference(param_name, member_type), index=index, ), ) - return building_block_factory.create_federated_map_or_apply( + return federated_language.framework.create_federated_map_or_apply( selecting_function, federated_tuple ) def _as_function_of_single_subparameter( - bb: building_blocks.Lambda, index: int -) -> building_blocks.Lambda: + bb: federated_language.framework.Lambda, index: int +) -> federated_language.framework.Lambda: """Turns `x -> ...only uses x_i...` into `x_i -> ...only uses x_i`.""" - tree_analysis.check_has_unique_names(bb) + federated_language.framework.check_has_unique_names(bb) bb = _prepare_for_rebinding(bb) - new_name = next(building_block_factory.unique_name_generator(bb)) - new_ref = building_blocks.Reference( + new_name = next(federated_language.framework.unique_name_generator(bb)) + new_ref = federated_language.framework.Reference( new_name, bb.type_signature.parameter[index] ) new_lambda_body = tree_transformations.replace_selections( bb.result, bb.parameter_name, {(index,): new_ref} ) - new_lambda = building_blocks.Lambda( + new_lambda = federated_language.framework.Lambda( new_ref.name, new_ref.type_signature, new_lambda_body ) - tree_analysis.check_contains_no_new_unbound_references(bb, new_lambda) + federated_language.framework.check_contains_no_new_unbound_references( + bb, new_lambda + ) return new_lambda @@ -573,13 +587,13 @@ class _MismatchedSelectionPlacementError(TypeError): def _as_function_of_some_federated_subparameters( - bb: building_blocks.Lambda, + bb: federated_language.framework.Lambda, paths, -) -> building_blocks.Lambda: +) -> federated_language.framework.Lambda: """Turns `x -> ...only uses parts of x...` into `parts_of_x -> ...`.""" - tree_analysis.check_has_unique_names(bb) + federated_language.framework.check_has_unique_names(bb) bb = _prepare_for_rebinding(bb) - name_generator = building_block_factory.unique_name_generator(bb) + name_generator = federated_language.framework.unique_name_generator(bb) type_list = [] int_paths = [] @@ -587,7 +601,7 @@ def _as_function_of_some_federated_subparameters( selected_type = bb.parameter_type int_path = [] for index in path: - if not isinstance(selected_type, computation_types.StructType): + if not isinstance(selected_type, federated_language.StructType): raise tree_transformations.ParameterSelectionError(path, bb) if isinstance(index, int): if index >= len(selected_type): @@ -599,7 +613,7 @@ def _as_function_of_some_federated_subparameters( raise tree_transformations.ParameterSelectionError(path, bb) int_path.append(structure.name_to_index_map(selected_type)[index]) selected_type = selected_type[index] - if not isinstance(selected_type, computation_types.FederatedType): + if not isinstance(selected_type, federated_language.FederatedType): raise _NonFederatedSelectionError( 'Attempted to rebind references to parameter selection path ' f'{path} from type {bb.parameter_type}, but the value at that path ' @@ -617,10 +631,12 @@ def _as_function_of_some_federated_subparameters( f'have resulted in the list of types:\n{type_list}' ) - zip_type = computation_types.FederatedType( + zip_type = federated_language.FederatedType( [x.member for x in type_list], placement=placement ) - ref_to_zip = building_blocks.Reference(next(name_generator), zip_type) + ref_to_zip = federated_language.framework.Reference( + next(name_generator), zip_type + ) path_to_replacement = {} for i, path in enumerate(int_paths): path_to_replacement[path] = _construct_selection_from_federated_tuple( @@ -630,10 +646,10 @@ def _as_function_of_some_federated_subparameters( new_lambda_body = tree_transformations.replace_selections( bb.result, bb.parameter_name, path_to_replacement ) - lambda_with_zipped_param = building_blocks.Lambda( + lambda_with_zipped_param = federated_language.framework.Lambda( ref_to_zip.name, ref_to_zip.type_signature, new_lambda_body ) - tree_analysis.check_contains_no_new_unbound_references( + federated_language.framework.check_contains_no_new_unbound_references( bb, lambda_with_zipped_param ) @@ -681,13 +697,13 @@ def _extract_prepare(before_broadcast, grappler_config): Args: before_broadcast: The first result of splitting `next_bb` on - `intrinsic_defs.FEDERATED_BROADCAST`. + `federated_language.framework.FEDERATED_BROADCAST`. grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure Grappler graph optimization. Returns: `prepare` as specified by `forms.MapReduceForm`, an instance of - `building_blocks.CompiledComputation`. + `federated_language.framework.CompiledComputation`. Raises: compiler.MapReduceFormCompilationError: If we extract an AST of the wrong @@ -718,7 +734,7 @@ def _extract_work(before_aggregate, grappler_config): Returns: `work` as specified by `forms.MapReduceForm`, an instance of - `building_blocks.CompiledComputation`. + `federated_language.framework.CompiledComputation`. Raises: compiler.MapReduceFormCompilationError: If we extract an AST of the wrong @@ -736,7 +752,7 @@ def _extract_work(before_aggregate, grappler_config): secure_sum_bitwidth_input_index = ('federated_secure_sum_bitwidth_param', 0) secure_sum_input_index = ('federated_secure_sum_param', 0) secure_modular_sum_input_index = ('federated_secure_modular_sum_param', 0) - work_unzipped = building_block_factory.select_output_from_lambda( + work_unzipped = federated_language.framework.select_output_from_lambda( work_to_before_aggregate, [ aggregate_input_index, @@ -745,10 +761,10 @@ def _extract_work(before_aggregate, grappler_config): secure_modular_sum_input_index, ], ) - work = building_blocks.Lambda( + work = federated_language.framework.Lambda( work_unzipped.parameter_name, work_unzipped.parameter_type, - building_block_factory.create_federated_zip(work_unzipped.result), + federated_language.framework.create_federated_zip(work_unzipped.result), ) return compiler.consolidate_and_extract_local_processing( work, grappler_config @@ -756,26 +772,27 @@ def _extract_work(before_aggregate, grappler_config): def _compile_selected_output_to_no_argument_tensorflow( - comp: building_blocks.Lambda, - path: building_block_factory.Path, + comp: federated_language.framework.Lambda, + path: federated_language.framework.Path, grappler_config, -) -> building_blocks.CompiledComputation: +) -> federated_language.framework.CompiledComputation: """Compiles the independent value result of `comp` at `path` to TensorFlow.""" - extracted = building_block_factory.select_output_from_lambda( + extracted = federated_language.framework.select_output_from_lambda( comp, path ).result return compiler.consolidate_and_extract_local_processing( - building_blocks.Lambda(None, None, extracted), grappler_config + federated_language.framework.Lambda(None, None, extracted), + grappler_config, ) def _compile_selected_output_as_tensorflow_function( - comp: building_blocks.Lambda, - path: building_block_factory.Path, + comp: federated_language.framework.Lambda, + path: federated_language.framework.Path, grappler_config, -) -> building_blocks.CompiledComputation: +) -> federated_language.framework.CompiledComputation: """Compiles the functional result of `comp` at `path` to TensorFlow.""" - extracted = building_block_factory.select_output_from_lambda( + extracted = federated_language.framework.select_output_from_lambda( comp, path ).result return compiler.consolidate_and_extract_local_processing( @@ -800,13 +817,13 @@ def _extract_federated_aggregate_functions(before_aggregate, grappler_config): Returns: `zero`, `accumulate`, `merge` and `report` as specified by `forms.MapReduceForm`. All are instances of - `building_blocks.CompiledComputation`. + `federated_language.framework.CompiledComputation`. Raises: compiler.MapReduceFormCompilationError: If we extract an ASTs of the wrong type. """ - federated_aggregate = building_block_factory.select_output_from_lambda( + federated_aggregate = federated_language.framework.select_output_from_lambda( before_aggregate, 'federated_aggregate_param' ) # Index `0` is the value being aggregated. @@ -841,16 +858,16 @@ def _extract_update(after_aggregate, grappler_config): Returns: `update` as specified by `forms.MapReduceForm`, an instance of - `building_blocks.CompiledComputation`. + `federated_language.framework.CompiledComputation`. Raises: compiler.MapReduceFormCompilationError: If we extract an AST of the wrong type. """ - after_aggregate_zipped = building_blocks.Lambda( + after_aggregate_zipped = federated_language.framework.Lambda( after_aggregate.parameter_name, after_aggregate.parameter_type, - building_block_factory.create_federated_zip(after_aggregate.result), + federated_language.framework.create_federated_zip(after_aggregate.result), ) # `create_federated_zip` doesn't have unique reference names, but we need # them for `as_function_of_some_federated_subparameters`. @@ -885,23 +902,23 @@ def _extract_update(after_aggregate, grappler_config): # > into # # unpack = > -> - name_generator = building_block_factory.unique_name_generator( + name_generator = federated_language.framework.unique_name_generator( update_with_flat_inputs ) unpack_param_name = next(name_generator) original_param_type = update_with_flat_inputs.parameter_type.member # pytype: disable=attribute-error - unpack_param_type = computation_types.StructType([ + unpack_param_type = federated_language.StructType([ original_param_type[0], - computation_types.StructType(original_param_type[1:]), + federated_language.StructType(original_param_type[1:]), ]) - unpack_param_ref = building_blocks.Reference( + unpack_param_ref = federated_language.framework.Reference( unpack_param_name, unpack_param_type ) - select = lambda bb, i: building_blocks.Selection(bb, index=i) - unpack = building_blocks.Lambda( + select = lambda bb, i: federated_language.framework.Selection(bb, index=i) + unpack = federated_language.framework.Lambda( unpack_param_name, unpack_param_type, - building_blocks.Struct( + federated_language.framework.Struct( [select(unpack_param_ref, 0)] + [ select(select(unpack_param_ref, 1), i) @@ -912,16 +929,16 @@ def _extract_update(after_aggregate, grappler_config): # update = v -> update_with_flat_inputs(federated_map(unpack, v)) param_name = next(name_generator) - param_type = computation_types.FederatedType( - unpack_param_type, placements.SERVER + param_type = federated_language.FederatedType( + unpack_param_type, federated_language.SERVER ) - param_ref = building_blocks.Reference(param_name, param_type) - update = building_blocks.Lambda( + param_ref = federated_language.framework.Reference(param_name, param_type) + update = federated_language.framework.Lambda( param_name, param_type, - building_blocks.Call( + federated_language.framework.Call( update_with_flat_inputs, - building_block_factory.create_federated_map_or_apply( + federated_language.framework.create_federated_map_or_apply( unpack, param_ref ), ), @@ -932,8 +949,8 @@ def _extract_update(after_aggregate, grappler_config): def _replace_lambda_body_with_call_dominant_form( - comp: building_blocks.Lambda, -) -> building_blocks.Lambda: + comp: federated_language.framework.Lambda, +) -> federated_language.framework.Lambda: """Transforms the body of `comp` to call-dominant form. Call-dominant form ensures that all higher-order functions are fully @@ -944,16 +961,16 @@ def _replace_lambda_body_with_call_dominant_form( intrinsics which will cause that function to fail. Args: - comp: `building_blocks.Lambda` the body of which to convert to call-dominant - form. + comp: `federated_language.framework.Lambda` the body of which to convert to + call-dominant form. Returns: A transformed version of `comp`, whose body is call-dominant. """ transformed = transformations.to_call_dominant(comp) - if not isinstance(transformed, building_blocks.Lambda): - raise building_blocks.UnexpectedBlockError( - building_blocks.Lambda, transformed + if not isinstance(transformed, federated_language.framework.Lambda): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Lambda, transformed ) return transformed @@ -969,7 +986,7 @@ def _merge_grappler_config_with_default( def get_broadcast_form_for_computation( - comp: computation_impl.ConcreteComputation, + comp: federated_language.framework.ConcreteComputation, grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG, *, tff_internal_preprocessing: Optional[BuildingBlockFn] = None, @@ -977,10 +994,10 @@ def get_broadcast_form_for_computation( """Constructs `tff.backends.mapreduce.BroadcastForm` given a computation. Args: - comp: An instance of `computation_impl.ConcreteComputation` that is - compatible with broadcast form. Computations are only compatible if they - take in a single value placed at server, return a single value placed at - clients, and do not contain any aggregations. + comp: An instance of `federated_language.framework.ConcreteComputation` that + is compatible with broadcast form. Computations are only compatible if + they take in a single value placed at server, return a single value placed + at clients, and do not contain any aggregations. grappler_config: An instance of `tf.compat.v1.ConfigProto` to configure Grappler graph optimization of the Tensorflow graphs backing the resulting `tff.backends.mapreduce.BroadcastForm`. These options are combined with a @@ -995,7 +1012,9 @@ def get_broadcast_form_for_computation( An instance of `tff.backends.mapreduce.BroadcastForm` equivalent to the provided `tff.Computation`. """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) + py_typecheck.check_type( + comp, federated_language.framework.ConcreteComputation + ) _check_function_signature_compatible_with_broadcast_form(comp.type_signature) py_typecheck.check_type(grappler_config, tf.compat.v1.ConfigProto) grappler_config = _merge_grappler_config_with_default(grappler_config) @@ -1006,8 +1025,8 @@ def get_broadcast_form_for_computation( bb, _ = tensorflow_tree_transformations.replace_intrinsics_with_bodies(bb) bb = _replace_lambda_body_with_call_dominant_form(bb) - tree_analysis.check_contains_only_reducible_intrinsics(bb) - aggregations = tree_analysis.find_aggregations_in_tree(bb) + federated_language.framework.check_contains_only_reducible_intrinsics(bb) + aggregations = federated_language.framework.find_aggregations_in_tree(bb) if aggregations: raise ValueError( '`get_broadcast_form_for_computation` called with computation' @@ -1024,9 +1043,9 @@ def get_broadcast_form_for_computation( ) def _create_comp(proto): - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) compute_server_context, client_processing = ( @@ -1047,7 +1066,7 @@ def _create_comp(proto): def get_map_reduce_form_for_computation( - comp: computation_impl.ConcreteComputation, + comp: federated_language.framework.ConcreteComputation, grappler_config: tf.compat.v1.ConfigProto = _GRAPPLER_DEFAULT_CONFIG, *, tff_internal_preprocessing: Optional[BuildingBlockFn] = None, @@ -1078,7 +1097,9 @@ def get_map_reduce_form_for_computation( TypeError: If the arguments are of the wrong types. compiler.MapReduceFormCompilationError: If the compilation process fails. """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) + py_typecheck.check_type( + comp, federated_language.framework.ConcreteComputation + ) comp_bb = check_computation_compatible_with_map_reduce_form( comp, tff_internal_preprocessing=tff_internal_preprocessing ) @@ -1110,9 +1131,9 @@ def get_map_reduce_form_for_computation( update = _extract_update(after_aggregate, grappler_config) def _create_comp(proto): - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) blocks = ( @@ -1132,15 +1153,15 @@ def _create_comp(proto): def get_distribute_aggregate_form_for_computation( - comp: computation_impl.ConcreteComputation, + comp: federated_language.framework.ConcreteComputation, *, tff_internal_preprocessing: Optional[BuildingBlockFn] = None, ) -> forms.DistributeAggregateForm: """Constructs `DistributeAggregateForm` for a computation. Args: - comp: An instance of `computation_impl.ConcreteComputation` that is - compatible with `DistributeAggregateForm`. The computation must take + comp: An instance of `federated_language.framework.ConcreteComputation` that + is compatible with `DistributeAggregateForm`. The computation must take exactly two arguments, and the first must be a state value placed at `SERVER`. The computation must return exactly two values. The type of the first element in the result must also be assignable to the first element @@ -1150,12 +1171,12 @@ def get_distribute_aggregate_form_for_computation( Returns: An instance of `tff.backends.mapreduce.DistributeAggregateForm` equivalent - to the provided `computation_base.Computation`. + to the provided `federated_language.framework.Computation`. Raises: TypeError: If the arguments are of the wrong types. """ - py_typecheck.check_type(comp, computation_base.Computation) + py_typecheck.check_type(comp, federated_language.framework.Computation) # Apply any requested preprocessing to the computation. comp_tree = comp.to_building_block() @@ -1166,7 +1187,7 @@ def get_distribute_aggregate_form_for_computation( comp_type = comp_tree.type_signature _check_type_is_fn(comp_type, '`comp`', TypeError) if ( - not isinstance(comp_type.parameter, computation_types.StructType) + not isinstance(comp_type.parameter, federated_language.StructType) or len(comp_type.parameter) != 2 ): # pytype: disable=attribute-error raise TypeError( @@ -1174,20 +1195,22 @@ def get_distribute_aggregate_form_for_computation( f' type:\n{comp_type.parameter}' # pytype: disable=attribute-error ) if ( - not isinstance(comp_type.result, computation_types.StructType) + not isinstance(comp_type.result, federated_language.StructType) or len(comp_type.result) != 2 ): # pytype: disable=attribute-error raise TypeError( 'Expected `comp` to return two values, found result ' f'type:\n{comp_type.result}' # pytype: disable=attribute-error ) - if not isinstance(comp_tree, building_blocks.Lambda): - raise building_blocks.UnexpectedBlockError( - building_blocks.Lambda, comp_tree + if not isinstance(comp_tree, federated_language.framework.Lambda): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Lambda, comp_tree ) comp_tree = _replace_lambda_body_with_call_dominant_form(comp_tree) comp_tree, _ = tree_transformations.uniquify_reference_names(comp_tree) - tree_analysis.check_broadcast_not_dependent_on_aggregate(comp_tree) + federated_language.framework.check_broadcast_not_dependent_on_aggregate( + comp_tree + ) # To generate the DistributeAggregateForm for the computation, we will split # the computation twice, first on broadcast intrinsics and then on aggregation @@ -1204,18 +1227,24 @@ def get_distribute_aggregate_form_for_computation( # of the split on broadcast intrinsics rather than potentially appearing in # the *last* part of the split on broadcast intrinsics. args_needing_broadcast_dependency = [] - unbound_refs = transformation_utils.get_map_of_unbound_references(comp_tree) + unbound_refs = federated_language.framework.get_map_of_unbound_references( + comp_tree + ) def _find_non_client_placed_args(inner_comp): # Examine the args of the aggregation intrinsic calls. if ( - isinstance(inner_comp, building_blocks.Call) - and isinstance(inner_comp.function, building_blocks.Intrinsic) + isinstance(inner_comp, federated_language.framework.Call) + and isinstance( + inner_comp.function, federated_language.framework.Intrinsic + ) and inner_comp.function.intrinsic_def().aggregation_kind ): aggregation_args = ( inner_comp.argument - if isinstance(inner_comp.argument, building_blocks.Struct) + if isinstance( + inner_comp.argument, federated_language.framework.Struct + ) else [inner_comp.argument] ) unbound_ref_names_for_intrinsic = unbound_refs[inner_comp.argument] @@ -1228,31 +1257,41 @@ def _find_non_client_placed_args(inner_comp): # federated broadcast that depends on it by normalizing it to a # server-placed value. if not isinstance( - aggregation_arg.type_signature, computation_types.FederatedType + aggregation_arg.type_signature, federated_language.FederatedType ): def _has_placement(type_spec): return isinstance( - type_spec.type_signature, computation_types.FederatedType + type_spec.type_signature, federated_language.FederatedType ) - if tree_analysis.count(aggregation_arg, _has_placement) > 0: + if ( + federated_language.framework.computation_count( + aggregation_arg, _has_placement + ) + > 0 + ): raise TypeError( 'DistributeAggregateForm cannot handle an aggregation ' f'intrinsic arg with type {aggregation_arg.type_signature}' ) args_needing_broadcast_dependency.append( - building_block_factory.create_federated_value( - aggregation_arg, placements.SERVER + federated_language.framework.create_federated_value( + aggregation_arg, federated_language.SERVER ) ) - elif aggregation_arg.type_signature.placement == placements.SERVER: + elif ( + aggregation_arg.type_signature.placement + == federated_language.SERVER + ): args_needing_broadcast_dependency.append(aggregation_arg) return inner_comp, True return inner_comp, False - tree_analysis.visit_preorder(comp_tree, _find_non_client_placed_args) + federated_language.framework.visit_preorder( + comp_tree, _find_non_client_placed_args + ) # Add an injected broadcast call to the computation that depends on the # identified non-client-placed args, if any exist. To avoid broadcasting the @@ -1265,16 +1304,18 @@ def _has_placement(type_spec): # here can be drastically simplified. if args_needing_broadcast_dependency: zipped_args_needing_broadcast_dependency = ( - building_block_factory.create_federated_zip( - building_blocks.Struct(args_needing_broadcast_dependency) + federated_language.framework.create_federated_zip( + federated_language.framework.Struct( + args_needing_broadcast_dependency + ) ) ) - injected_broadcast = building_block_factory.create_federated_broadcast( - building_block_factory.create_federated_apply( - building_blocks.Lambda( + injected_broadcast = federated_language.framework.create_federated_broadcast( + federated_language.framework.create_federated_apply( + federated_language.framework.Lambda( 'ignored_param', zipped_args_needing_broadcast_dependency.type_signature.member, - building_blocks.Struct([]), + federated_language.framework.Struct([]), ), zipped_args_needing_broadcast_dependency, ) @@ -1288,17 +1329,17 @@ def _has_placement(type_spec): # does not get pruned by various tree transformations. We will remove this # additional element in the result after the first split operation. revised_block_result = structure.to_elements(comp_tree.result.result) + [ - building_blocks.Reference( + federated_language.framework.Reference( 'injected_broadcast_ref', injected_broadcast.type_signature, ) ] - comp_tree = building_blocks.Lambda( + comp_tree = federated_language.framework.Lambda( comp_tree.parameter_name, comp_tree.parameter_type, - building_blocks.Block( + federated_language.framework.Block( revised_block_locals, - building_blocks.Struct(revised_block_result), + federated_language.framework.Struct(revised_block_result), ), ) @@ -1319,10 +1360,10 @@ def _has_placement(type_spec): server_prepare, server_to_client_broadcast, after_broadcast = ( transformations.divisive_force_align_and_split_by_intrinsics( comp_tree, - intrinsic_defs.get_broadcast_intrinsics(), - before_comp_allowed_original_arg_subparameters=[ - (server_state_index,) - ], + federated_language.framework.get_broadcast_intrinsics(), + before_comp_allowed_original_arg_subparameters=[( + server_state_index, + )], intrinsic_comp_allowed_original_arg_subparameters=[], after_comp_allowed_original_arg_subparameters=[(client_data_index,)], ) @@ -1331,21 +1372,23 @@ def _has_placement(type_spec): # Helper method to replace a lambda with parameter that is a single-element # struct with a lambda that uses the element directly. def _unnest_lambda_parameter(comp): - assert isinstance(comp, building_blocks.Lambda) - assert isinstance(comp.parameter_type, computation_types.StructType) + assert isinstance(comp, federated_language.framework.Lambda) + assert isinstance(comp.parameter_type, federated_language.StructType) - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) new_param_name = next(name_generator) - replacement_ref = building_blocks.Reference( + replacement_ref = federated_language.framework.Reference( new_param_name, comp.parameter_type[0] ) modified_comp_body = tree_transformations.replace_selections( comp.result, comp.parameter_name, {(0,): replacement_ref} ) - modified_comp = building_blocks.Lambda( + modified_comp = federated_language.framework.Lambda( replacement_ref.name, replacement_ref.type_signature, modified_comp_body ) - tree_analysis.check_contains_no_unbound_references(modified_comp) + federated_language.framework.check_contains_no_unbound_references( + modified_comp + ) return modified_comp @@ -1376,22 +1419,24 @@ def _unnest_lambda_parameter(comp): # input that represents the intermediate state that was produced in the first # split. if args_needing_broadcast_dependency: - assert isinstance(after_broadcast.result.result, building_blocks.Struct) + assert isinstance( + after_broadcast.result.result, federated_language.framework.Struct + ) # Check that the last element of the result is the expected empty struct # associated with the injected broadcast call. result_len = len(after_broadcast.result.result) injected_broadcast_result = after_broadcast.result.result[result_len - 1] assert isinstance( injected_broadcast_result.type_signature.member, - computation_types.StructType, + federated_language.StructType, ) assert not injected_broadcast_result.type_signature.member - after_broadcast = building_blocks.Lambda( + after_broadcast = federated_language.framework.Lambda( after_broadcast.parameter_name, after_broadcast.parameter_type, - building_blocks.Block( + federated_language.framework.Block( after_broadcast.result.locals, - building_blocks.Struct( + federated_language.framework.Struct( structure.to_elements(after_broadcast.result.result)[:-1] ), ), @@ -1402,14 +1447,14 @@ def _unnest_lambda_parameter(comp): client_work, client_to_server_aggregation, server_result = ( transformations.divisive_force_align_and_split_by_intrinsics( after_broadcast, - intrinsic_defs.get_aggregation_intrinsics(), + federated_language.framework.get_aggregation_intrinsics(), before_comp_allowed_original_arg_subparameters=[ (client_data_index_in_after_broadcast_param,), (intrinsic_results_index_in_after_broadcast_param,), ], - intrinsic_comp_allowed_original_arg_subparameters=[ - (intermediate_state_index_in_after_broadcast_param,) - ], + intrinsic_comp_allowed_original_arg_subparameters=[( + intermediate_state_index_in_after_broadcast_param, + )], after_comp_allowed_original_arg_subparameters=[ (intermediate_state_index_in_after_broadcast_param,), ], @@ -1419,7 +1464,7 @@ def _unnest_lambda_parameter(comp): # Drop the intermediate_state produced by the second split that is part of # the client_work output. index_of_intrinsic_args_in_client_work_result = 0 - client_work = building_block_factory.select_output_from_lambda( + client_work = federated_language.framework.select_output_from_lambda( client_work, index_of_intrinsic_args_in_client_work_result ) @@ -1445,9 +1490,9 @@ def _unnest_lambda_parameter(comp): ) def _create_comp(proto): - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) comps = [_create_comp(bb.proto) for bb in blocks] diff --git a/tensorflow_federated/python/core/backends/mapreduce/form_utils_test.py b/tensorflow_federated/python/core/backends/mapreduce/form_utils_test.py index e9a3ba93ca..7c9f78fb3d 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/form_utils_test.py +++ b/tensorflow_federated/python/core/backends/mapreduce/form_utils_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf import tree @@ -28,17 +29,7 @@ from tensorflow_federated.python.core.backends.test import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import iterative_process @@ -67,10 +58,10 @@ def get_iterative_process_for_sum_example(): `forms.MapReduceForm`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value([0, 0], placements.SERVER) + return federated_language.federated_value([0, 0], federated_language.SERVER) @tensorflow_computation.tf_computation([np.int32, np.int32]) def prepare(server_state): @@ -89,24 +80,28 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation([ - computation_types.FederatedType([np.int32, np.int32], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" - s2 = intrinsics.federated_map(prepare, server_state) - client_input = intrinsics.federated_broadcast(s2) - c3 = intrinsics.federated_zip([client_data, client_input]) - client_updates = intrinsics.federated_map(work, c3) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_update = intrinsics.federated_secure_sum_bitwidth( + s2 = federated_language.federated_map(prepare, server_state) + client_input = federated_language.federated_broadcast(s2) + c3 = federated_language.federated_zip([client_data, client_input]) + client_updates = federated_language.federated_map(work, c3) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_update = federated_language.federated_secure_sum_bitwidth( client_updates[1], 8 ) - s6 = intrinsics.federated_zip( + s6 = federated_language.federated_zip( [server_state, [unsecure_update, secure_update]] ) - new_server_state, server_output = intrinsics.federated_map(update, s6) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -136,32 +131,38 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation( - computation_types.FederatedType([np.int32, np.int32], placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ) ) def broadcast_and_return_arg_and_result(x): - broadcasted = intrinsics.federated_broadcast(x) + broadcasted = federated_language.federated_broadcast(x) return [broadcasted, x] - @federated_computation.federated_computation([ - computation_types.FederatedType([np.int32, np.int32], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def comp_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" - s2 = intrinsics.federated_map(prepare, server_state) + s2 = federated_language.federated_map(prepare, server_state) unused_client_input, to_broadcast = broadcast_and_return_arg_and_result(s2) - client_input = intrinsics.federated_broadcast(to_broadcast) - c3 = intrinsics.federated_zip([client_data, client_input]) - client_updates = intrinsics.federated_map(work, c3) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_update = intrinsics.federated_secure_sum_bitwidth( + client_input = federated_language.federated_broadcast(to_broadcast) + c3 = federated_language.federated_zip([client_data, client_input]) + client_updates = federated_language.federated_map(work, c3) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_update = federated_language.federated_secure_sum_bitwidth( client_updates[1], 8 ) - s6 = intrinsics.federated_zip( + s6 = federated_language.federated_zip( [server_state, [unsecure_update, secure_update]] ) - new_server_state, server_output = intrinsics.federated_map(update, s6) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return comp_fn @@ -174,10 +175,10 @@ def get_iterative_process_for_sum_example_with_no_prepare(): function before the `federated_broadcast`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value([0, 0], placements.SERVER) + return federated_language.federated_value([0, 0], federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32, [np.int32, np.int32]) def work(client_data, client_input): @@ -192,24 +193,28 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation([ - computation_types.FederatedType([np.int32, np.int32], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" # No call to `federated_map` with a `prepare` function. - client_input = intrinsics.federated_broadcast(server_state) - c3 = intrinsics.federated_zip([client_data, client_input]) - client_updates = intrinsics.federated_map(work, c3) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_update = intrinsics.federated_secure_sum_bitwidth( + client_input = federated_language.federated_broadcast(server_state) + c3 = federated_language.federated_zip([client_data, client_input]) + client_updates = federated_language.federated_map(work, c3) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_update = federated_language.federated_secure_sum_bitwidth( client_updates[1], 8 ) - s6 = intrinsics.federated_zip( + s6 = federated_language.federated_zip( [server_state, [unsecure_update, secure_update]] ) - new_server_state, server_output = intrinsics.federated_map(update, s6) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -223,10 +228,10 @@ def get_iterative_process_for_sum_example_with_no_broadcast(): prepare function before the `federated_broadcast`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value([0, 0], placements.SERVER) + return federated_language.federated_value([0, 0], federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32) def work(client_data): @@ -240,23 +245,27 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation([ - computation_types.FederatedType([np.int32, np.int32], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" # No call to `federated_map` with prepare. # No call to `federated_broadcast`. - client_updates = intrinsics.federated_map(work, client_data) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_update = intrinsics.federated_secure_sum_bitwidth( + client_updates = federated_language.federated_map(work, client_data) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_update = federated_language.federated_secure_sum_bitwidth( client_updates[1], 8 ) - s6 = intrinsics.federated_zip( + s6 = federated_language.federated_zip( [server_state, [unsecure_update, secure_update]] ) - new_server_state, server_output = intrinsics.federated_map(update, s6) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -268,10 +277,10 @@ def get_iterative_process_for_sum_example_with_no_federated_aggregate(): This iterative process does not have a call to `federated_aggregate`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32) def prepare(server_state): @@ -288,20 +297,24 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" - s2 = intrinsics.federated_map(prepare, server_state) - client_input = intrinsics.federated_broadcast(s2) - c3 = intrinsics.federated_zip([client_data, client_input]) - client_updates = intrinsics.federated_map(work, c3) + s2 = federated_language.federated_map(prepare, server_state) + client_input = federated_language.federated_broadcast(s2) + c3 = federated_language.federated_zip([client_data, client_input]) + client_updates = federated_language.federated_map(work, c3) # No call to `federated_aggregate`. - secure_update = intrinsics.federated_secure_sum_bitwidth(client_updates, 8) - s6 = intrinsics.federated_zip([server_state, secure_update]) - new_server_state, server_output = intrinsics.federated_map(update, s6) + secure_update = federated_language.federated_secure_sum_bitwidth( + client_updates, 8 + ) + s6 = federated_language.federated_zip([server_state, secure_update]) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -314,10 +327,10 @@ def get_iterative_process_for_sum_example_with_no_federated_secure_sum_bitwidth( `federated_secure_sum_bitwidth`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32) def prepare(server_state): @@ -334,20 +347,22 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" - s2 = intrinsics.federated_map(prepare, server_state) - client_input = intrinsics.federated_broadcast(s2) - c3 = intrinsics.federated_zip([client_data, client_input]) - client_updates = intrinsics.federated_map(work, c3) - unsecure_update = intrinsics.federated_sum(client_updates) + s2 = federated_language.federated_map(prepare, server_state) + client_input = federated_language.federated_broadcast(s2) + c3 = federated_language.federated_zip([client_data, client_input]) + client_updates = federated_language.federated_map(work, c3) + unsecure_update = federated_language.federated_sum(client_updates) # No call to `federated_secure_sum_bitwidth`. - s6 = intrinsics.federated_zip([server_state, unsecure_update]) - new_server_state, server_output = intrinsics.federated_map(update, s6) + s6 = federated_language.federated_zip([server_state, unsecure_update]) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -360,10 +375,10 @@ def get_iterative_process_for_sum_example_with_no_update(): function before the `federated_broadcast`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value([0, 0], placements.SERVER) + return federated_language.federated_value([0, 0], federated_language.SERVER) @tensorflow_computation.tf_computation([np.int32, np.int32]) def prepare(server_state): @@ -375,25 +390,29 @@ def work(client_data, client_input): del client_input # Unused return 1, 1 - @federated_computation.federated_computation([ - computation_types.FederatedType([np.int32, np.int32], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" - s2 = intrinsics.federated_map(prepare, server_state) - client_input = intrinsics.federated_broadcast(s2) - c3 = intrinsics.federated_zip([client_data, client_input]) - client_updates = intrinsics.federated_map(work, c3) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_update = intrinsics.federated_secure_sum_bitwidth( + s2 = federated_language.federated_map(prepare, server_state) + client_input = federated_language.federated_broadcast(s2) + c3 = federated_language.federated_zip([client_data, client_input]) + client_updates = federated_language.federated_map(work, c3) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_update = federated_language.federated_secure_sum_bitwidth( client_updates[1], 8 ) - new_server_state = intrinsics.federated_zip( + new_server_state = federated_language.federated_zip( [unsecure_update, secure_update] ) # No call to `federated_map` with an `update` function. - server_output = intrinsics.federated_value([], placements.SERVER) + server_output = federated_language.federated_value( + [], federated_language.SERVER + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -409,10 +428,10 @@ def get_iterative_process_for_sum_example_with_no_server_state(): the `federated_broadcast`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value([], placements.SERVER) + return federated_language.federated_value([], federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32) def work(client_data): @@ -423,24 +442,26 @@ def work(client_data): def update(global_update): return global_update - @federated_computation.federated_computation([ - computation_types.FederatedType([], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType([], federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" del server_state # Unused # No call to `federated_map` with prepare. # No call to `federated_broadcast`. - client_updates = intrinsics.federated_map(work, client_data) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_update = intrinsics.federated_secure_sum_bitwidth( + client_updates = federated_language.federated_map(work, client_data) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_update = federated_language.federated_secure_sum_bitwidth( client_updates[1], 8 ) - s5 = intrinsics.federated_zip([unsecure_update, secure_update]) + s5 = federated_language.federated_zip([unsecure_update, secure_update]) # Empty server state. - new_server_state = intrinsics.federated_value([], placements.SERVER) - server_output = intrinsics.federated_map(update, s5) + new_server_state = federated_language.federated_value( + [], federated_language.SERVER + ) + server_output = federated_language.federated_map(update, s5) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -454,10 +475,10 @@ def get_iterative_process_for_sum_example_with_no_aggregation(): `forms.MapReduceForm`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" - return intrinsics.federated_value([0, 0], placements.SERVER) + return federated_language.federated_value([0, 0], federated_language.SERVER) @tensorflow_computation.tf_computation( [np.int32, np.int32], [np.int32, np.int32] @@ -466,21 +487,29 @@ def update(server_state, global_update): del server_state # Unused return global_update, [] - @federated_computation.federated_computation([ - computation_types.FederatedType([np.int32, np.int32], placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" del client_data # No call to `federated_aggregate`. - unsecure_update = intrinsics.federated_value(1, placements.SERVER) + unsecure_update = federated_language.federated_value( + 1, federated_language.SERVER + ) # No call to `federated_secure_sum_bitwidth`. - secure_update = intrinsics.federated_value(1, placements.SERVER) - s6 = intrinsics.federated_zip( + secure_update = federated_language.federated_value( + 1, federated_language.SERVER + ) + s6 = federated_language.federated_zip( [server_state, [unsecure_update, secure_update]] ) - new_server_state, server_output = intrinsics.federated_map(update, s6) + new_server_state, server_output = federated_language.federated_map( + update, s6 + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -493,34 +522,36 @@ def get_iterative_process_for_minimal_sum_example(): `forms.MapReduceForm`. """ - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): """The `init` function for `tff.templates.IterativeProcess`.""" zero = tensorflow_computation.tf_computation(lambda: [0, 0, 0, 0]) - return intrinsics.federated_eval(zero, placements.SERVER) + return federated_language.federated_eval(zero, federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32) def work(client_data): del client_data # Unused return 1, 1, 1, 1 - @federated_computation.federated_computation([ - computation_types.FederatedType( - [np.int32, np.int32, np.int32, np.int32], placements.SERVER + @federated_language.federated_computation([ + federated_language.FederatedType( + [np.int32, np.int32, np.int32, np.int32], federated_language.SERVER ), - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def next_fn(server_state, client_data): """The `next` function for `tff.templates.IterativeProcess`.""" del server_state # Unused # No call to `federated_map` with prepare. # No call to `federated_broadcast`. - client_updates = intrinsics.federated_map(work, client_data) - unsecure_update = intrinsics.federated_sum(client_updates[0]) - secure_sum_bitwidth_update = intrinsics.federated_secure_sum_bitwidth( - client_updates[1], bitwidth=8 + client_updates = federated_language.federated_map(work, client_data) + unsecure_update = federated_language.federated_sum(client_updates[0]) + secure_sum_bitwidth_update = ( + federated_language.federated_secure_sum_bitwidth( + client_updates[1], bitwidth=8 + ) ) - secure_sum_update = intrinsics.federated_secure_sum( + secure_sum_update = federated_language.federated_secure_sum( client_updates[2], max_input=1 ) secure_modular_sum_update = ( @@ -528,14 +559,16 @@ def next_fn(server_state, client_data): client_updates[3], modulus=8 ) ) - new_server_state = intrinsics.federated_zip([ + new_server_state = federated_language.federated_zip([ unsecure_update, secure_sum_bitwidth_update, secure_sum_update, secure_modular_sum_update, ]) # No call to `federated_map` with an `update` function. - server_output = intrinsics.federated_value([], placements.SERVER) + server_output = federated_language.federated_value( + [], federated_language.SERVER + ) return new_server_state, server_output return iterative_process.IterativeProcess(init_fn, next_fn) @@ -569,21 +602,21 @@ def get_example_cf_compatible_iterative_processes(): def _count_tensorflow_variables_under( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, ) -> int: count_vars = 0 def _count_tensorflow_variables_in( - comp: building_blocks.CompiledComputation, + comp: federated_language.framework.CompiledComputation, ) -> int: """Counts TF Variables in `comp` if `comp` is a TF block.""" if ( - not isinstance(comp, building_blocks.CompiledComputation) + not isinstance(comp, federated_language.framework.CompiledComputation) or comp.proto.WhichOneof('computation') != 'tensorflow' ): raise ValueError( 'Please pass a ' - '`building_blocks.CompiledComputation` of the ' + '`federated_language.framework.CompiledComputation` of the ' '`tensorflow` variety to `count_tensorflow_variables_in`.' ) graph_def = serialization_utils.unpack_graph_def( @@ -613,12 +646,12 @@ def _count_vars_in_function_lib(func_library): def _count_tf_vars(inner_comp): nonlocal count_vars if ( - isinstance(inner_comp, building_blocks.CompiledComputation) + isinstance(inner_comp, federated_language.framework.CompiledComputation) and inner_comp.proto.WhichOneof('computation') == 'tensorflow' ): count_vars += _count_tensorflow_variables_in(inner_comp) - tree_analysis.visit_postorder(comp, _count_tf_vars) + federated_language.framework.visit_postorder(comp, _count_tf_vars) return count_vars @@ -696,12 +729,14 @@ def test_next_computation_returning_tensor_fails_well(self): distribute_aggregate_test_utils.get_temperature_sensor_example().initialize ) init_result = initialize.type_signature.result - lam = building_blocks.Lambda( - 'x', init_result, building_blocks.Reference('x', init_result) + lam = federated_language.framework.Lambda( + 'x', + init_result, + federated_language.framework.Reference('x', init_result), ) - bad_comp = computation_impl.ConcreteComputation( + bad_comp = federated_language.framework.ConcreteComputation( computation_proto=lam.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) with self.assertRaises(TypeError): form_utils.get_distribute_aggregate_form_for_computation(bad_comp) @@ -710,21 +745,21 @@ def test_broadcast_dependent_on_aggregate_fails_well(self): example = distribute_aggregate_test_utils.get_mnist_training_example() comp = form_utils.get_computation_for_distribute_aggregate_form(example.daf) comp_bb = comp.to_building_block() - top_level_param = building_blocks.Reference( + top_level_param = federated_language.framework.Reference( comp_bb.parameter_name, comp_bb.parameter_type ) - first_result = building_blocks.Call(comp_bb, top_level_param) - middle_param = building_blocks.Struct([ - building_blocks.Selection(first_result, index=0), - building_blocks.Selection(top_level_param, index=1), + first_result = federated_language.framework.Call(comp_bb, top_level_param) + middle_param = federated_language.framework.Struct([ + federated_language.framework.Selection(first_result, index=0), + federated_language.framework.Selection(top_level_param, index=1), ]) - second_result = building_blocks.Call(comp_bb, middle_param) - not_reducible = building_blocks.Lambda( + second_result = federated_language.framework.Call(comp_bb, middle_param) + not_reducible = federated_language.framework.Lambda( comp_bb.parameter_name, comp_bb.parameter_type, second_result ) - bad_comp = computation_impl.ConcreteComputation( + bad_comp = federated_language.framework.ConcreteComputation( computation_proto=not_reducible.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'): form_utils.get_distribute_aggregate_form_for_computation(bad_comp) @@ -892,20 +927,24 @@ def test_allows_valid_computation(self, ip): def test_disallows_broadcast_dependent_on_aggregate(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType((), placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType((), federated_language.CLIENTS), ) def comp(server_state, client_data): del server_state, client_data - client_val = intrinsics.federated_value(0, placements.CLIENTS) - server_agg = intrinsics.federated_sum(client_val) + client_val = federated_language.federated_value( + 0, federated_language.CLIENTS + ) + server_agg = federated_language.federated_sum(client_val) # This broadcast is dependent on the result of the above aggregation, # which is not supported by MapReduce form. - broadcasted = intrinsics.federated_broadcast(server_agg) - server_agg_again = intrinsics.federated_sum(broadcasted) + broadcasted = federated_language.federated_broadcast(server_agg) + server_agg_again = federated_language.federated_sum(broadcasted) # `next` must return two values. - return server_agg_again, intrinsics.federated_value((), placements.SERVER) + return server_agg_again, federated_language.federated_value( + (), federated_language.SERVER + ) with self.assertRaises(ValueError): form_utils.check_computation_compatible_with_map_reduce_form(comp) @@ -918,12 +957,14 @@ def test_next_computation_returning_tensor_fails_well(self): mapreduce_test_utils.get_temperature_sensor_example().initialize ) init_result = initialize.type_signature.result - lam = building_blocks.Lambda( - 'x', init_result, building_blocks.Reference('x', init_result) + lam = federated_language.framework.Lambda( + 'x', + init_result, + federated_language.framework.Reference('x', init_result), ) - bad_comp = computation_impl.ConcreteComputation( + bad_comp = federated_language.framework.ConcreteComputation( computation_proto=lam.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) with self.assertRaises(TypeError): form_utils.get_map_reduce_form_for_computation(bad_comp) @@ -932,21 +973,21 @@ def test_broadcast_dependent_on_aggregate_fails_well(self): example = mapreduce_test_utils.get_temperature_sensor_example() comp = form_utils.get_computation_for_map_reduce_form(example.mrf) comp_bb = comp.to_building_block() - top_level_param = building_blocks.Reference( + top_level_param = federated_language.framework.Reference( comp_bb.parameter_name, comp_bb.parameter_type ) - first_result = building_blocks.Call(comp_bb, top_level_param) - middle_param = building_blocks.Struct([ - building_blocks.Selection(first_result, index=0), - building_blocks.Selection(top_level_param, index=1), + first_result = federated_language.framework.Call(comp_bb, top_level_param) + middle_param = federated_language.framework.Struct([ + federated_language.framework.Selection(first_result, index=0), + federated_language.framework.Selection(top_level_param, index=1), ]) - second_result = building_blocks.Call(comp_bb, middle_param) - not_reducible = building_blocks.Lambda( + second_result = federated_language.framework.Call(comp_bb, middle_param) + not_reducible = federated_language.framework.Lambda( comp_bb.parameter_name, comp_bb.parameter_type, second_result ) - bad_comp = computation_impl.ConcreteComputation( + bad_comp = federated_language.framework.ConcreteComputation( computation_proto=not_reducible.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'): @@ -1084,9 +1125,9 @@ def get_map_reduce_form_for_client_to_server_fn( A `forms.MapReduceForm` which uses the embedded `client_to_server_fn`. """ - @federated_computation.federated_computation([ - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation([ + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def comp_fn(server_state, client_data): server_output = client_to_server_fn(client_data) @@ -1096,13 +1137,13 @@ def comp_fn(server_state, client_data): def test_returns_map_reduce_form_with_secure_sum_bitwidth(self): mrf = self.get_map_reduce_form_for_client_to_server_fn( - lambda data: intrinsics.federated_secure_sum_bitwidth(data, 7) + lambda data: federated_language.federated_secure_sum_bitwidth(data, 7) ) self.assertEqual(mrf.secure_sum_bitwidth(), (7,)) def test_returns_map_reduce_form_with_secure_sum_max_input(self): mrf = self.get_map_reduce_form_for_client_to_server_fn( - lambda data: intrinsics.federated_secure_sum(data, 12) + lambda data: federated_language.federated_secure_sum(data, 12) ) self.assertEqual(mrf.secure_sum_max_input(), (12,)) @@ -1117,40 +1158,44 @@ class BroadcastFormTest(absltest.TestCase): def test_roundtrip(self): add = tensorflow_computation.tf_computation(lambda x, y: x + y) - server_data_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_data_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_data_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + client_data_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - @federated_computation.federated_computation( + @federated_language.federated_computation( server_data_type, client_data_type ) def add_server_number_plus_one(server_number, client_numbers): - one = intrinsics.federated_value(1, placements.SERVER) - server_context = intrinsics.federated_map(add, (one, server_number)) - client_context = intrinsics.federated_broadcast(server_context) - return intrinsics.federated_map(add, (client_context, client_numbers)) + one = federated_language.federated_value(1, federated_language.SERVER) + server_context = federated_language.federated_map( + add, (one, server_number) + ) + client_context = federated_language.federated_broadcast(server_context) + return federated_language.federated_map( + add, (client_context, client_numbers) + ) bf = form_utils.get_broadcast_form_for_computation( add_server_number_plus_one ) self.assertEqual(bf.server_data_label, 'server_number') self.assertEqual(bf.client_data_label, 'client_numbers') - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( bf.compute_server_context.type_signature, - computation_types.FunctionType(np.int32, (np.int32,)), + federated_language.FunctionType(np.int32, (np.int32,)), ) self.assertEqual(2, bf.compute_server_context(1)[0]) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( bf.client_processing.type_signature, - computation_types.FunctionType(((np.int32,), np.int32), np.int32), + federated_language.FunctionType(((np.int32,), np.int32), np.int32), ) self.assertEqual(3, bf.client_processing((1,), 2)) round_trip_comp = form_utils.get_computation_for_broadcast_form(bf) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( round_trip_comp.type_signature, add_server_number_plus_one.type_signature, ) @@ -1159,33 +1204,35 @@ def add_server_number_plus_one(server_number, client_numbers): def test_roundtrip_no_broadcast(self): add_five = tensorflow_computation.tf_computation(lambda x: x + 5) - server_data_type = computation_types.FederatedType((), placements.SERVER) - client_data_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + server_data_type = federated_language.FederatedType( + (), federated_language.SERVER + ) + client_data_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - @federated_computation.federated_computation( + @federated_language.federated_computation( server_data_type, client_data_type ) def add_five_at_clients(naught_at_server, client_numbers): del naught_at_server - return intrinsics.federated_map(add_five, client_numbers) + return federated_language.federated_map(add_five, client_numbers) bf = form_utils.get_broadcast_form_for_computation(add_five_at_clients) self.assertEqual(bf.server_data_label, 'naught_at_server') self.assertEqual(bf.client_data_label, 'client_numbers') - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( bf.compute_server_context.type_signature, - computation_types.FunctionType((), ()), + federated_language.FunctionType((), ()), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( bf.client_processing.type_signature, - computation_types.FunctionType(((), np.int32), np.int32), + federated_language.FunctionType(((), np.int32), np.int32), ) self.assertEqual(6, bf.client_processing((), 1)) round_trip_comp = form_utils.get_computation_for_broadcast_form(bf) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( round_trip_comp.type_signature, add_five_at_clients.type_signature ) self.assertEqual([10, 11, 12], round_trip_comp((), [5, 6, 7])) @@ -1196,47 +1243,57 @@ class AsFunctionOfSingleSubparameterTest(absltest.TestCase): def assert_selected_param_to_result_type(self, old_lam, new_lam, index): old_type = old_lam.type_signature new_type = new_lam.type_signature - self.assertIsInstance(old_type, computation_types.FunctionType) - self.assertIsInstance(new_type, computation_types.FunctionType) - type_test_utils.assert_types_equivalent( + self.assertIsInstance(old_type, federated_language.FunctionType) + self.assertIsInstance(new_type, federated_language.FunctionType) + federated_language.framework.assert_types_equivalent( new_type, - computation_types.FunctionType( + federated_language.FunctionType( old_type.parameter[index], old_type.result ), ) def test_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + tuple_of_federated_types = federated_language.StructType( [fed_at_clients, fed_at_server] ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=0, ), ) new_lam = form_utils._as_function_of_single_subparameter(lam, 0) self.assert_selected_param_to_result_type(lam, new_lam, 0) def test_named_element_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType([ + tuple_of_federated_types = federated_language.StructType([ (None, fed_at_server), ('a', fed_at_clients), ]) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), name='a' + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + name='a', ), ) new_lam = form_utils._as_function_of_single_subparameter(lam, 1) @@ -1246,139 +1303,165 @@ def test_named_element_selection(self): class AsFunctionOfSomeSubparametersTest(tf.test.TestCase): def test_raises_on_non_tuple_parameter(self): - lam = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + lam = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) ) with self.assertRaises(tree_transformations.ParameterSelectionError): form_utils._as_function_of_some_federated_subparameters(lam, [(0,)]) def test_raises_on_selection_from_non_tuple(self): - lam = building_blocks.Lambda( - 'x', [np.int32], building_blocks.Reference('x', [np.int32]) + lam = federated_language.framework.Lambda( + 'x', [np.int32], federated_language.framework.Reference('x', [np.int32]) ) with self.assertRaises(tree_transformations.ParameterSelectionError): form_utils._as_function_of_some_federated_subparameters(lam, [(0, 0)]) def test_raises_on_non_federated_selection(self): - lam = building_blocks.Lambda( - 'x', [np.int32], building_blocks.Reference('x', [np.int32]) + lam = federated_language.framework.Lambda( + 'x', [np.int32], federated_language.framework.Reference('x', [np.int32]) ) with self.assertRaises(form_utils._NonFederatedSelectionError): form_utils._as_function_of_some_federated_subparameters(lam, [(0,)]) def test_raises_on_selections_at_different_placements(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + tuple_of_federated_types = federated_language.StructType( [fed_at_clients, fed_at_server] ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=0, ), ) with self.assertRaises(form_utils._MismatchedSelectionPlacementError): form_utils._as_function_of_some_federated_subparameters(lam, [(0,), (1,)]) def test_single_element_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER + ) + tuple_of_federated_types = federated_language.StructType( [fed_at_clients, fed_at_server] ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=0, ), ) new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0,)] ) - expected_parameter_type = computation_types.FederatedType( - (np.int32,), placements.CLIENTS + expected_parameter_type = federated_language.FederatedType( + (np.int32,), federated_language.CLIENTS ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( new_lam.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( expected_parameter_type, lam.result.type_signature ), ) def test_single_named_element_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + tuple_of_federated_types = federated_language.StructType( [('a', fed_at_clients), ('b', fed_at_server)] ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), name='a' + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + name='a', ), ) new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0,)] ) - expected_parameter_type = computation_types.FederatedType( - (np.int32,), placements.CLIENTS + expected_parameter_type = federated_language.FederatedType( + (np.int32,), federated_language.CLIENTS ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( new_lam.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( expected_parameter_type, lam.result.type_signature ), ) def test_single_element_selection_leaves_no_unbound_references(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + tuple_of_federated_types = federated_language.StructType( [fed_at_clients, fed_at_server] ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=0, ), ) new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0,)] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_lam - )[new_lam] + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_lam)[ + new_lam + ] + ) self.assertEmpty(unbound_references) def test_single_nested_element_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER + ) + tuple_of_federated_types = federated_language.StructType( [[fed_at_clients], fed_at_server] ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), index=0, ), index=0, @@ -1388,127 +1471,157 @@ def test_single_nested_element_selection(self): new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0, 0)] ) - expected_parameter_type = computation_types.FederatedType( - (np.int32,), placements.CLIENTS + expected_parameter_type = federated_language.FederatedType( + (np.int32,), federated_language.CLIENTS ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( new_lam.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( expected_parameter_type, lam.result.type_signature ), ) def test_multiple_nested_element_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + tuple_of_federated_types = federated_language.StructType( [[fed_at_clients], fed_at_server, [fed_at_clients]] ) - first_selection = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=0 + first_selection = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=0, ), index=0, ) - second_selection = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=2 + second_selection = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=2, ), index=0, ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Struct([first_selection, second_selection]), + federated_language.framework.Struct( + [first_selection, second_selection] + ), ) new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0, 0), (2, 0)] ) - expected_parameter_type = computation_types.FederatedType( - (np.int32, np.int32), placements.CLIENTS + expected_parameter_type = federated_language.FederatedType( + (np.int32, np.int32), federated_language.CLIENTS ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( new_lam.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( expected_parameter_type, lam.result.type_signature ), ) def test_multiple_nested_named_element_selection(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType([ + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER + ) + tuple_of_federated_types = federated_language.StructType([ ('a', [('a', fed_at_clients)]), ('b', fed_at_server), ('c', [('c', fed_at_clients)]), ]) - first_selection = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), name='a' + first_selection = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + name='a', ), name='a', ) - second_selection = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), name='c' + second_selection = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + name='c', ), name='c', ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Struct([first_selection, second_selection]), + federated_language.framework.Struct( + [first_selection, second_selection] + ), ) new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0, 0), (2, 0)] ) - expected_parameter_type = computation_types.FederatedType( - (np.int32, np.int32), placements.CLIENTS + expected_parameter_type = federated_language.FederatedType( + (np.int32, np.int32), federated_language.CLIENTS ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( new_lam.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( expected_parameter_type, lam.result.type_signature ), ) def test_binding_multiple_args_results_in_unique_names(self): - fed_at_clients = computation_types.FederatedType( - np.int32, placements.CLIENTS + fed_at_clients = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - fed_at_server = computation_types.FederatedType(np.int32, placements.SERVER) - tuple_of_federated_types = computation_types.StructType( + fed_at_server = federated_language.FederatedType( + np.int32, federated_language.SERVER + ) + tuple_of_federated_types = federated_language.StructType( [[fed_at_clients], fed_at_server, [fed_at_clients]] ) - first_selection = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=0 + first_selection = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=0, ), index=0, ) - second_selection = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', tuple_of_federated_types), index=2 + second_selection = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', tuple_of_federated_types + ), + index=2, ), index=0, ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', tuple_of_federated_types, - building_blocks.Struct([first_selection, second_selection]), + federated_language.framework.Struct( + [first_selection, second_selection] + ), ) new_lam = form_utils._as_function_of_some_federated_subparameters( lam, [(0, 0), (2, 0)] ) - tree_analysis.check_has_unique_names(new_lam) + federated_language.framework.check_has_unique_names(new_lam) if __name__ == '__main__': diff --git a/tensorflow_federated/python/core/backends/mapreduce/forms.py b/tensorflow_federated/python/core/backends/mapreduce/forms.py index c4915f895d..a71985a228 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/forms.py +++ b/tensorflow_federated/python/core/backends/mapreduce/forms.py @@ -16,21 +16,17 @@ from collections.abc import Callable from typing import Optional +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import typed_object def _check_tensorflow_computation(label, comp): - py_typecheck.check_type(comp, computation_impl.ConcreteComputation, label) - comp_proto = computation_impl.ConcreteComputation.get_proto(comp) + py_typecheck.check_type( + comp, federated_language.framework.ConcreteComputation, label + ) + comp_proto = federated_language.framework.ConcreteComputation.get_proto(comp) which_comp = comp_proto.WhichOneof('computation') if which_comp != 'tensorflow': raise TypeError( @@ -40,41 +36,49 @@ def _check_tensorflow_computation(label, comp): def _check_lambda_computation(label, comp): - py_typecheck.check_type(comp, computation_impl.ConcreteComputation, label) - comp_proto = computation_impl.ConcreteComputation.get_proto(comp) + py_typecheck.check_type( + comp, federated_language.framework.ConcreteComputation, label + ) + comp_proto = federated_language.framework.ConcreteComputation.get_proto(comp) which_comp = comp_proto.WhichOneof('computation') if which_comp != 'lambda': raise TypeError( 'Expected all computations supplied as arguments to ' 'be Lambda computations, found {}.'.format(which_comp) ) - tree_analysis.check_contains_no_unbound_references(comp.to_building_block()) - tree_analysis.check_has_unique_names(comp.to_building_block()) + federated_language.framework.check_contains_no_unbound_references( + comp.to_building_block() + ) + federated_language.framework.check_has_unique_names(comp.to_building_block()) def _check_flattened_intrinsic_args_are_selections_or_literals( - value: building_blocks.ComputationBuildingBlock, + value: federated_language.framework.ComputationBuildingBlock, expected_reference_name: str, ): """Checks the flattened args of an intrinsic are Selections or Literals.""" - if isinstance(value, building_blocks.Struct): + if isinstance(value, federated_language.framework.Struct): inner_values = structure.flatten(value) else: inner_values = [value] for inner_value in inner_values: if not isinstance( - inner_value, (building_blocks.Literal, building_blocks.Selection) + inner_value, + ( + federated_language.framework.Literal, + federated_language.framework.Selection, + ), ): raise TypeError( 'Expected that all arguments to an intrinsic call are selections or' ' literals or structs containing only selections or literals, found' f' {type(inner_value)}.' ) - if isinstance(inner_value, building_blocks.Selection): + if isinstance(inner_value, federated_language.framework.Selection): source = inner_value.source if ( - isinstance(source, building_blocks.Reference) + isinstance(source, federated_language.framework.Reference) and source.name != expected_reference_name ): raise TypeError( @@ -90,15 +94,15 @@ def _is_assignable_from_or_both_none(first, second): return first.is_assignable_from(second) -def _is_tuple(type_signature: computation_types.Type, length: int) -> bool: +def _is_tuple(type_signature: federated_language.Type, length: int) -> bool: return ( - isinstance(type_signature, computation_types.StructType) + isinstance(type_signature, federated_language.StructType) and len(type_signature) == length ) def _check_accepts_tuple( - label: str, comp: computation_base.Computation, length: int + label: str, comp: federated_language.framework.Computation, length: int ): param_type = comp.type_signature.parameter if not _is_tuple(param_type, length): @@ -109,7 +113,7 @@ def _check_accepts_tuple( def _check_returns_tuple( - label: str, comp: computation_base.Computation, length: int + label: str, comp: federated_language.framework.Computation, length: int ): result_type = comp.type_signature.result if not _is_tuple(result_type, length): @@ -220,7 +224,7 @@ def summary(self, print_fn=print): WORK_RESULT_LEN = 4 -class MapReduceForm(typed_object.TypedObject): +class MapReduceForm(federated_language.TypedObject): """Standardized representation of logic deployable to MapReduce-like systems. This class docstring describes the purpose of `MapReduceForm` as a data @@ -273,17 +277,17 @@ class MapReduceForm(typed_object.TypedObject): def __init__( self, - type_signature: computation_types.FunctionType, - prepare: computation_impl.ConcreteComputation, - work: computation_impl.ConcreteComputation, - zero: computation_impl.ConcreteComputation, - accumulate: computation_impl.ConcreteComputation, - merge: computation_impl.ConcreteComputation, - report: computation_impl.ConcreteComputation, - secure_sum_bitwidth: computation_impl.ConcreteComputation, - secure_sum_max_input: computation_impl.ConcreteComputation, - secure_modular_sum_modulus: computation_impl.ConcreteComputation, - update: computation_impl.ConcreteComputation, + type_signature: federated_language.FunctionType, + prepare: federated_language.framework.ConcreteComputation, + work: federated_language.framework.ConcreteComputation, + zero: federated_language.framework.ConcreteComputation, + accumulate: federated_language.framework.ConcreteComputation, + merge: federated_language.framework.ConcreteComputation, + report: federated_language.framework.ConcreteComputation, + secure_sum_bitwidth: federated_language.framework.ConcreteComputation, + secure_sum_max_input: federated_language.framework.ConcreteComputation, + secure_modular_sum_modulus: federated_language.framework.ConcreteComputation, + update: federated_language.framework.ConcreteComputation, ): """Constructs a representation of a MapReduce-like computation. @@ -352,7 +356,7 @@ def __init__( if ( isinstance( - accumulate.type_signature.parameter, computation_types.StructType + accumulate.type_signature.parameter, federated_language.StructType ) and len(accumulate.type_signature.parameter) != 2 ): @@ -381,7 +385,9 @@ def __init__( ) # pytype: disable=unsupported-operands if ( - isinstance(merge.type_signature.parameter, computation_types.StructType) + isinstance( + merge.type_signature.parameter, federated_language.StructType + ) and len(merge.type_signature.parameter) != 2 ): raise ValueError( @@ -403,7 +409,7 @@ def __init__( merge.type_signature.result ) # pytype: disable=attribute-error - expected_update_parameter_type = computation_types.to_type([ + expected_update_parameter_type = federated_language.to_type([ type_signature.parameter[0].member, # pytype: disable=unsupported-operands [ report.type_signature.result, @@ -418,14 +424,14 @@ def __init__( # input. Verifying it aligns with a tff.Computation that produces an initial # state should be verified outside of the constructor of the MapReduceForm. if not _is_assignable_from_or_both_none( - computation_types.to_type(update.type_signature.parameter), + federated_language.to_type(update.type_signature.parameter), expected_update_parameter_type, ): raise TypeError( 'The `update` computation expects arguments of type {}, ' 'which does not match the expected {} as implied by the type ' 'signatures of `report` and `work`.'.format( - computation_types.to_type(update.type_signature.parameter[1:]), # pytype: disable=unsupported-operands + federated_language.to_type(update.type_signature.parameter[1:]), # pytype: disable=unsupported-operands expected_update_parameter_type, ) ) @@ -458,48 +464,54 @@ def __init__( self._server_state_label, self._client_data_label = parameter_names @property - def type_signature(self) -> computation_types.FunctionType: + def type_signature(self) -> federated_language.FunctionType: """Returns the TFF type of the equivalent `tff.Computation`.""" return self._type_signature @property - def prepare(self) -> computation_impl.ConcreteComputation: + def prepare(self) -> federated_language.framework.ConcreteComputation: return self._prepare @property - def work(self) -> computation_impl.ConcreteComputation: + def work(self) -> federated_language.framework.ConcreteComputation: return self._work @property - def zero(self) -> computation_impl.ConcreteComputation: + def zero(self) -> federated_language.framework.ConcreteComputation: return self._zero @property - def accumulate(self) -> computation_impl.ConcreteComputation: + def accumulate(self) -> federated_language.framework.ConcreteComputation: return self._accumulate @property - def merge(self) -> computation_impl.ConcreteComputation: + def merge(self) -> federated_language.framework.ConcreteComputation: return self._merge @property - def report(self) -> computation_impl.ConcreteComputation: + def report(self) -> federated_language.framework.ConcreteComputation: return self._report @property - def secure_sum_bitwidth(self) -> computation_impl.ConcreteComputation: + def secure_sum_bitwidth( + self, + ) -> federated_language.framework.ConcreteComputation: return self._secure_sum_bitwidth @property - def secure_sum_max_input(self) -> computation_impl.ConcreteComputation: + def secure_sum_max_input( + self, + ) -> federated_language.framework.ConcreteComputation: return self._secure_sum_max_input @property - def secure_modular_sum_modulus(self) -> computation_impl.ConcreteComputation: + def secure_modular_sum_modulus( + self, + ) -> federated_language.framework.ConcreteComputation: return self._secure_modular_sum_modulus @property - def update(self) -> computation_impl.ConcreteComputation: + def update(self) -> federated_language.framework.ConcreteComputation: return self._update @property @@ -523,7 +535,7 @@ def securely_aggregates_tensors(self) -> bool: secagg_max_input_type, secagg_modulus_type, ]: - if type_analysis.contains_tensor_types(secagg_type): + if federated_language.framework.contains_tensor_types(secagg_type): return True return False @@ -555,7 +567,7 @@ def summary(self, print_fn: Callable[..., None] = print) -> None: ) -class DistributeAggregateForm(typed_object.TypedObject): +class DistributeAggregateForm(federated_language.TypedObject): """Standard representation of logic deployable to a federated learning system. This class docstring describes the purpose of `DistributeAggregateForm` as a @@ -586,12 +598,12 @@ class DistributeAggregateForm(typed_object.TypedObject): def __init__( self, - type_signature: computation_types.FunctionType, - server_prepare: computation_impl.ConcreteComputation, - server_to_client_broadcast: computation_impl.ConcreteComputation, - client_work: computation_impl.ConcreteComputation, - client_to_server_aggregation: computation_impl.ConcreteComputation, - server_result: computation_impl.ConcreteComputation, + type_signature: federated_language.FunctionType, + server_prepare: federated_language.framework.ConcreteComputation, + server_to_client_broadcast: federated_language.framework.ConcreteComputation, + client_work: federated_language.framework.ConcreteComputation, + client_to_server_aggregation: federated_language.framework.ConcreteComputation, + server_result: federated_language.framework.ConcreteComputation, ): """Constructs a representation of a round for a federated learning system. @@ -628,8 +640,8 @@ def __init__( # represents the server state and produce 2 results (data to broadcast and # temporary state). It should contain only server placements. _check_returns_tuple('server_prepare', server_prepare, length=2) - tree_analysis.check_has_single_placement( - server_prepare.to_building_block(), placements.SERVER + federated_language.framework.check_has_single_placement( + server_prepare.to_building_block(), federated_language.SERVER ) # The broadcast function can take an arbitrary number of inputs and produce @@ -641,12 +653,12 @@ def __init__( local_name, local_value, ) in server_to_client_broadcast.to_building_block().result.locals: # pytype: disable=attribute-error - if not isinstance(local_value, building_blocks.Call): + if not isinstance(local_value, federated_language.framework.Call): raise ValueError( f'Expected a `tff.framework.Call`, found {type(local_value)}.' ) local_fn = local_value.function - if not isinstance(local_fn, building_blocks.Intrinsic): + if not isinstance(local_fn, federated_language.framework.Intrinsic): raise ValueError( f'Expected a `tff.framework.Intrinsic`, found {type(local_fn)}.' ) @@ -663,7 +675,7 @@ def __init__( expected_return_references.append(local_name) if not isinstance( server_to_client_broadcast.to_building_block().result.result, # pytype: disable=attribute-error - building_blocks.Struct, + federated_language.framework.Struct, ): raise ValueError( 'Expected a `tff.framework.Struct`, found' @@ -683,8 +695,8 @@ def __init__( # data) and produce an output of arbitrary length that represents the data # to aggregate. It should contain only CLIENTS placements. _check_accepts_tuple('client_work', client_work, length=2) - tree_analysis.check_has_single_placement( - client_work.to_building_block(), placements.CLIENTS + federated_language.framework.check_has_single_placement( + client_work.to_building_block(), federated_language.CLIENTS ) # The client_to_server_aggregation function should take 2 inputs (temporary @@ -700,12 +712,12 @@ def __init__( local_name, local_value, ) in client_to_server_aggregation.to_building_block().result.locals: # pytype: disable=attribute-error - if not isinstance(local_value, building_blocks.Call): + if not isinstance(local_value, federated_language.framework.Call): raise ValueError( f'Expected a `tff.framework.Call`, found {type(local_value)}.' ) local_fn = local_value.function - if not isinstance(local_fn, building_blocks.Intrinsic): + if not isinstance(local_fn, federated_language.framework.Intrinsic): raise ValueError( f'Expected a `tff.framework.Intrinsic`, found {type(local_fn)}.' ) @@ -724,7 +736,9 @@ def __init__( aggregation_result_result = ( client_to_server_aggregation.to_building_block().result.result # pytype: disable=attribute-error ) - if not isinstance(aggregation_result_result, building_blocks.Struct): + if not isinstance( + aggregation_result_result, federated_language.framework.Struct + ): raise ValueError( 'Expected a `tff.framework.Struct`, found' f' {type(aggregation_result_result)}.' @@ -743,8 +757,8 @@ def __init__( # output). It should contain only SERVER placements. _check_accepts_tuple('server_result', server_result, length=2) _check_returns_tuple('server_result', server_result, length=2) - tree_analysis.check_has_single_placement( - server_result.to_building_block(), placements.SERVER + federated_language.framework.check_has_single_placement( + server_result.to_building_block(), federated_language.SERVER ) # The broadcast input data types in the 'server_prepare' result and @@ -895,30 +909,32 @@ def __init__( self._server_result = server_result @property - def type_signature(self) -> computation_types.FunctionType: + def type_signature(self) -> federated_language.FunctionType: """Returns the TFF type of the equivalent `tff.Computation`.""" return self._type_signature @property - def server_prepare(self) -> computation_impl.ConcreteComputation: + def server_prepare(self) -> federated_language.framework.ConcreteComputation: return self._server_prepare @property - def server_to_client_broadcast(self) -> computation_impl.ConcreteComputation: + def server_to_client_broadcast( + self, + ) -> federated_language.framework.ConcreteComputation: return self._server_to_client_broadcast @property - def client_work(self) -> computation_impl.ConcreteComputation: + def client_work(self) -> federated_language.framework.ConcreteComputation: return self._client_work @property def client_to_server_aggregation( self, - ) -> computation_impl.ConcreteComputation: + ) -> federated_language.framework.ConcreteComputation: return self._client_to_server_aggregation @property - def server_result(self) -> computation_impl.ConcreteComputation: + def server_result(self) -> federated_language.framework.ConcreteComputation: return self._server_result def summary(self, print_fn: Callable[..., None] = print) -> None: diff --git a/tensorflow_federated/python/core/backends/mapreduce/forms_test.py b/tensorflow_federated/python/core/backends/mapreduce/forms_test.py index 1baef5bafd..fcb93eaec1 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/forms_test.py +++ b/tensorflow_federated/python/core/backends/mapreduce/forms_test.py @@ -13,6 +13,7 @@ # limitations under the License. from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf @@ -20,16 +21,12 @@ from tensorflow_federated.python.core.backends.mapreduce import forms from tensorflow_federated.python.core.backends.mapreduce import mapreduce_test_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _test_broadcast_form_computations(): server_data_type = (np.int32, np.int32) context_type = np.int32 - client_data_type = computation_types.SequenceType(np.float32) + client_data_type = federated_language.SequenceType(np.float32) @tensorflow_computation.tf_computation(server_data_type) def compute_server_context(server_data): @@ -53,7 +50,7 @@ def prepare(server_state): return 1.0 @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.float32), np.float32 + federated_language.SequenceType(np.float32), np.float32 ) def work(client_data, client_input): del client_data # Unused @@ -87,7 +84,7 @@ def report(accumulator): bitwidth = unit_comp max_input = unit_comp modulus = unit_comp - unit_type = computation_types.to_type([]) + unit_type = federated_language.to_type([]) @tensorflow_computation.tf_computation( np.int32, (np.float32, unit_type, unit_type, unit_type) @@ -159,8 +156,8 @@ def _build_test_map_reduce_form_with_computations( def _test_distribute_aggregate_form_computations(): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def server_prepare(server_state): @tensorflow_computation.tf_computation @@ -171,25 +168,28 @@ def server_prepare_broadcast_tf(): def server_prepare_state_tf(): return 32 - return [ - [ - intrinsics.federated_value( - server_prepare_broadcast_tf(), placements.SERVER - ) - ] - ], [server_prepare_state_tf(), server_state] + return [[ + federated_language.federated_value( + server_prepare_broadcast_tf(), federated_language.SERVER + ) + ]], [server_prepare_state_tf(), server_state] - @federated_computation.federated_computation( - [[computation_types.FederatedType(np.float32, placements.SERVER)]] - ) + @federated_language.federated_computation([[ + federated_language.FederatedType(np.float32, federated_language.SERVER) + ]]) def server_to_client_broadcast(context_at_server): - return [intrinsics.federated_broadcast(context_at_server[0][0])] + return [federated_language.federated_broadcast(context_at_server[0][0])] - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.SequenceType(np.float32), placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.SequenceType(np.float32), + federated_language.CLIENTS, ), - [computation_types.FederatedType(np.float32, placements.CLIENTS)], + [ + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ) + ], ) def client_work(client_data, context_at_clients): @tensorflow_computation.tf_computation @@ -198,32 +198,48 @@ def client_work_tf(): del client_data # Unused del context_at_clients # Unused - return [[intrinsics.federated_value(client_work_tf(), placements.CLIENTS)]] + return [[ + federated_language.federated_value( + client_work_tf(), federated_language.CLIENTS + ) + ]] - @federated_computation.federated_computation( - [np.int32, computation_types.FederatedType(np.int32, placements.SERVER)], + @federated_language.federated_computation( + [ + np.int32, + federated_language.FederatedType(np.int32, federated_language.SERVER), + ], [[ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.CLIENTS + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.CLIENTS, ) ]], ) def client_to_server_aggregation(temp_server_state, client_updates): del temp_server_state # Unused. - return [intrinsics.federated_secure_sum_bitwidth(client_updates[0][0], 100)] + return [ + federated_language.federated_secure_sum_bitwidth( + client_updates[0][0], 100 + ) + ] - @federated_computation.federated_computation( - [np.int32, computation_types.FederatedType(np.int32, placements.SERVER)], + @federated_language.federated_computation( + [ + np.int32, + federated_language.FederatedType(np.int32, federated_language.SERVER), + ], [ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.SERVER + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.SERVER, ) ], ) def server_result(temp_server_state, aggregated_results): del aggregated_results # Unused - return temp_server_state[1], intrinsics.federated_value( - [], placements.SERVER + return temp_server_state[1], federated_language.federated_value( + [], federated_language.SERVER ) return ( @@ -308,7 +324,7 @@ def test_init_does_not_raise_type_error(self): self.fail('Raised TypeError unexpectedly.') def test_init_does_not_raise_type_error_with_unknown_dimensions(self): - server_state_type = computation_types.TensorType(np.int32, [None]) + server_state_type = federated_language.TensorType(np.int32, [None]) @tensorflow_computation.tf_computation(server_state_type) def prepare(server_state): @@ -316,7 +332,7 @@ def prepare(server_state): return 1.0 @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.float32), np.float32 + federated_language.SequenceType(np.float32), np.float32 ) def work(client_data, client_input): del client_data # Unused @@ -328,7 +344,7 @@ def zero(): return tf.constant([], dtype=tf.string) @tensorflow_computation.tf_computation( - computation_types.TensorType(np.str_, [None]), np.bool_ + federated_language.TensorType(np.str_, [None]), np.bool_ ) def accumulate(accumulator, client_update): del accumulator # Unused @@ -336,8 +352,8 @@ def accumulate(accumulator, client_update): return tf.constant(['abc']) @tensorflow_computation.tf_computation( - computation_types.TensorType(np.str_, [None]), - computation_types.TensorType(np.str_, [None]), + federated_language.TensorType(np.str_, [None]), + federated_language.TensorType(np.str_, [None]), ) def merge(accumulator1, accumulator2): del accumulator1 # Unused @@ -345,7 +361,7 @@ def merge(accumulator1, accumulator2): return tf.constant(['abc']) @tensorflow_computation.tf_computation( - computation_types.TensorType(np.str_, [None]) + federated_language.TensorType(np.str_, [None]) ) def report(accumulator): del accumulator # Unused @@ -355,7 +371,7 @@ def report(accumulator): bitwidth = unit_comp max_input = unit_comp modulus = unit_comp - unit_type = computation_types.to_type([]) + unit_type = federated_language.to_type([]) @tensorflow_computation.tf_computation( server_state_type, (np.float32, unit_type, unit_type, unit_type) @@ -409,8 +425,9 @@ def prepare(server_state): _build_test_map_reduce_form_with_computations(prepare=prepare) def test_init_raises_type_error_with_bad_work_second_parameter_type(self): + @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.float32), np.int32 + federated_language.SequenceType(np.float32), np.int32 ) def work(client_data, client_input): del client_data # Unused @@ -421,8 +438,9 @@ def work(client_data, client_input): _build_test_map_reduce_form_with_computations(work=work) def test_init_raises_type_error_with_bad_work_result_type(self): + @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.float32), np.float32 + federated_language.SequenceType(np.float32), np.float32 ) def work(client_data, client_input): del client_data # Unused @@ -534,8 +552,9 @@ def report(accumulator): _build_test_map_reduce_form_with_computations(report=report) def test_init_raises_type_error_with_bad_update_first_parameter_type(self): + @tensorflow_computation.tf_computation( - np.float32, (np.float32, computation_types.StructType([])) + np.float32, (np.float32, federated_language.StructType([])) ) def update(server_state, global_update): del server_state # Unused @@ -546,8 +565,9 @@ def update(server_state, global_update): _build_test_map_reduce_form_with_computations(update=update) def test_init_raises_type_error_with_bad_update_second_parameter_type(self): + @tensorflow_computation.tf_computation( - np.int32, (np.int32, computation_types.StructType([])) + np.int32, (np.int32, federated_language.StructType([])) ) def update(server_state, global_update): del server_state # Unused @@ -558,8 +578,9 @@ def update(server_state, global_update): _build_test_map_reduce_form_with_computations(update=update) def test_init_raises_type_error_with_bad_update_result_type(self): + @tensorflow_computation.tf_computation( - np.int32, (np.float32, computation_types.StructType([])) + np.int32, (np.float32, federated_language.StructType([])) ) def update(server_state, global_update): del server_state # Unused @@ -618,43 +639,52 @@ def test_init_does_not_raise_type_error(self): except TypeError: self.fail('Raised TypeError unexpectedly.') assert daf.type_signature.is_equivalent_to( - computation_types.FunctionType( - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - computation_types.SequenceType(np.float32), - placements.CLIENTS, + federated_language.FunctionType( + federated_language.StructType([ + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + federated_language.FederatedType( + federated_language.SequenceType(np.float32), + federated_language.CLIENTS, ), ]), - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType([], placements.SERVER), + federated_language.StructType([ + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + federated_language.FederatedType([], federated_language.SERVER), ]), ) ) def test_init_does_not_raise_type_error_with_unknown_dimensions(self): - state_type = computation_types.TensorType(np.int32, [None]) + state_type = federated_language.TensorType(np.int32, [None]) - @federated_computation.federated_computation( - computation_types.FederatedType(state_type, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(state_type, federated_language.SERVER) ) def server_prepare(server_state): return [[ server_state, ]], [server_state] - @federated_computation.federated_computation( - [[computation_types.FederatedType(state_type, placements.SERVER)]] - ) + @federated_language.federated_computation([[ + federated_language.FederatedType(state_type, federated_language.SERVER) + ]]) def server_to_client_broadcast(context_at_server): - return [intrinsics.federated_broadcast(context_at_server[0][0])] + return [federated_language.federated_broadcast(context_at_server[0][0])] - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.SequenceType(np.float32), placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.SequenceType(np.float32), + federated_language.CLIENTS, ), - [computation_types.FederatedType(state_type, placements.CLIENTS)], + [ + federated_language.FederatedType( + state_type, federated_language.CLIENTS + ) + ], ) def client_work(client_data, context_at_clients): @tensorflow_computation.tf_computation @@ -663,21 +693,35 @@ def client_work_tf(): del client_data # Unused del context_at_clients # Unused - return [ - [intrinsics.federated_value(client_work_tf(), placements.CLIENTS)] - ] + return [[ + federated_language.federated_value( + client_work_tf(), federated_language.CLIENTS + ) + ]] - @federated_computation.federated_computation( - [computation_types.FederatedType(state_type, placements.SERVER)], - [[computation_types.FederatedType(np.int32, placements.CLIENTS)]], + @federated_language.federated_computation( + [ + federated_language.FederatedType( + state_type, federated_language.SERVER + ) + ], + [[ + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + ]], ) def client_to_server_aggregation(temp_server_state, client_updates): del temp_server_state # Unused - return [intrinsics.federated_sum(client_updates[0][0])] + return [federated_language.federated_sum(client_updates[0][0])] - @federated_computation.federated_computation( - [computation_types.FederatedType(state_type, placements.SERVER)], - [computation_types.FederatedType(np.int32, placements.SERVER)], + @federated_language.federated_computation( + [ + federated_language.FederatedType( + state_type, federated_language.SERVER + ) + ], + [federated_language.FederatedType(np.int32, federated_language.SERVER)], ) def server_result(temp_server_state, aggregated_results): return temp_server_state[0], aggregated_results[0] @@ -701,8 +745,8 @@ def server_result(temp_server_state, aggregated_results): def test_init_raises_type_error_with_bad_server_prepare_parameter_type(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.SERVER) ) def server_prepare(server_state): del server_state # Unused @@ -715,16 +759,14 @@ def server_prepare_broadcast_tf(): def server_prepare_state_tf(): return 32 - return [ - [ - intrinsics.federated_value( - server_prepare_broadcast_tf(), placements.SERVER - ) - ] - ], [ + return [[ + federated_language.federated_value( + server_prepare_broadcast_tf(), federated_language.SERVER + ) + ]], [ server_prepare_state_tf(), - intrinsics.federated_value( - server_prepare_state_tf(), placements.SERVER + federated_language.federated_value( + server_prepare_state_tf(), federated_language.SERVER ), ] @@ -737,8 +779,8 @@ def server_prepare_state_tf(): def test_init_raises_type_error_with_broadcast_input_type_mismatch(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def server_prepare(server_state): @tensorflow_computation.tf_computation @@ -757,11 +799,16 @@ def server_prepare_state_tf(): def test_init_raises_type_error_with_broadcast_output_type_mismatch(self): - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.SequenceType(np.float32), placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.SequenceType(np.float32), + federated_language.CLIENTS, ), - [computation_types.FederatedType(np.int32, placements.CLIENTS)], + [ + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + ], ) def client_work(client_data, context_at_clients): @tensorflow_computation.tf_computation @@ -770,9 +817,11 @@ def client_work_tf(): del client_data # Unused del context_at_clients # Unused - return [ - [intrinsics.federated_value(client_work_tf(), placements.CLIENTS)] - ] + return [[ + federated_language.federated_value( + client_work_tf(), federated_language.CLIENTS + ) + ]] with self.assertRaisesRegex( TypeError, 'The `client_work` computation expects an argument type' @@ -786,12 +835,12 @@ def test_init_raises_assertion_error_with_bad_broadcast_body(self): def multiply_tf(x): return x * 2.0 - @federated_computation.federated_computation( - [[computation_types.FederatedType(np.float32, placements.SERVER)]] - ) + @federated_language.federated_computation([[ + federated_language.FederatedType(np.float32, federated_language.SERVER) + ]]) def server_to_client_broadcast(context_at_server): - a = intrinsics.federated_map(multiply_tf, context_at_server[0][0]) - return [intrinsics.federated_broadcast(a)] + a = federated_language.federated_map(multiply_tf, context_at_server[0][0]) + return [federated_language.federated_broadcast(a)] with self.assertRaisesRegex( ValueError, 'Expected only broadcast intrinsics' @@ -802,17 +851,25 @@ def server_to_client_broadcast(context_at_server): def test_init_raises_type_error_with_aggregation_input_type_mismatch(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( [ np.int32, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ], - [[computation_types.FederatedType(np.int32, placements.CLIENTS)]], + [[ + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + ]], ) def client_to_server_aggregation(temp_server_state, client_updates): del temp_server_state # Unused. return [ - intrinsics.federated_secure_sum_bitwidth(client_updates[0][0], 100) + federated_language.federated_secure_sum_bitwidth( + client_updates[0][0], 100 + ) ] with self.assertRaisesRegex( @@ -828,21 +885,24 @@ def client_to_server_aggregation(temp_server_state, client_updates): def test_init_raises_type_error_with_aggregation_output_type_mismatch(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( [ np.int32, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ], [[ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.CLIENTS + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.CLIENTS, ) ]], ) def client_to_server_aggregation(temp_server_state, client_updates): del temp_server_state # Unused - a = intrinsics.federated_sum(client_updates[0][0]) - b = intrinsics.federated_sum(client_updates[0][0]) + a = federated_language.federated_sum(client_updates[0][0]) + b = federated_language.federated_sum(client_updates[0][0]) return [a, b] with self.assertRaisesRegex( @@ -854,41 +914,50 @@ def client_to_server_aggregation(temp_server_state, client_updates): def test_init_raises_assertion_error_with_bad_aggregation_body(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( [ np.int32, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ], [[ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.CLIENTS + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.CLIENTS, ) ]], ) def client_to_server_aggregation(temp_server_state, client_updates): del temp_server_state # Unused. - a = intrinsics.federated_secure_sum_bitwidth(client_updates[0][0], 100) - b = intrinsics.federated_sum(client_updates[0][0]) + a = federated_language.federated_secure_sum_bitwidth( + client_updates[0][0], 100 + ) + b = federated_language.federated_sum(client_updates[0][0]) return [b, a] - @federated_computation.federated_computation( + @federated_language.federated_computation( [ np.int32, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ], [ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.SERVER + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.SERVER, ), - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.SERVER + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.SERVER, ), ], ) def server_result(temp_server_state, aggregated_results): del aggregated_results # Unused - return temp_server_state[1], intrinsics.federated_value( - [], placements.SERVER + return temp_server_state[1], federated_language.federated_value( + [], federated_language.SERVER ) with self.assertRaisesRegex( @@ -901,40 +970,48 @@ def server_result(temp_server_state, aggregated_results): def test_init_raises_type_error_with_temporary_state_type_mismatch(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( [ np.int32, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), np.int32, ], [[ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.CLIENTS + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.CLIENTS, ) ]], ) def client_to_server_aggregation(temp_server_state, client_updates): del temp_server_state # Unused. return [ - intrinsics.federated_secure_sum_bitwidth(client_updates[0][0], 100) + federated_language.federated_secure_sum_bitwidth( + client_updates[0][0], 100 + ) ] - @federated_computation.federated_computation( + @federated_language.federated_computation( [ np.int32, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), np.int32, ], [ - computation_types.FederatedType( - computation_types.TensorType(np.int32, [2]), placements.SERVER + federated_language.FederatedType( + federated_language.TensorType(np.int32, [2]), + federated_language.SERVER, ) ], ) def server_result(temp_server_state, aggregated_results): del aggregated_results # Unused - return temp_server_state[1], intrinsics.federated_value( - [], placements.SERVER + return temp_server_state[1], federated_language.federated_value( + [], federated_language.SERVER ) with self.assertRaisesRegex( @@ -965,30 +1042,30 @@ def test_init_raises_type_error_with_type_signature_mismatch(self): ) ) - bad_server_state_parameter = computation_types.StructType([ - computation_types.FederatedType(np.float32, placements.SERVER), + bad_server_state_parameter = federated_language.StructType([ + federated_language.FederatedType(np.float32, federated_language.SERVER), test_client_work.type_signature.parameter[0], ]) - bad_client_data_parameter = computation_types.StructType([ + bad_client_data_parameter = federated_language.StructType([ test_server_result.type_signature.parameter[0], - computation_types.FederatedType(np.str_, placements.CLIENTS), + federated_language.FederatedType(np.str_, federated_language.CLIENTS), ]) - bad_server_state_result = computation_types.StructType([ - computation_types.FederatedType(np.float32, placements.SERVER), + bad_server_state_result = federated_language.StructType([ + federated_language.FederatedType(np.float32, federated_language.SERVER), test_server_result.type_signature.result[1], ]) - bad_server_output_result = computation_types.StructType([ + bad_server_output_result = federated_language.StructType([ test_server_result.type_signature.parameter[0], - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType(np.float32, federated_language.SERVER), ]) with self.assertRaisesRegex( TypeError, 'The original computation argument type' ): _build_test_distribute_aggregate_form_with_computations( - type_signature=computation_types.FunctionType( + type_signature=federated_language.FunctionType( bad_server_state_parameter, correct_type_signature.result ) ) @@ -997,7 +1074,7 @@ def test_init_raises_type_error_with_type_signature_mismatch(self): TypeError, 'The original computation argument type' ): _build_test_distribute_aggregate_form_with_computations( - type_signature=computation_types.FunctionType( + type_signature=federated_language.FunctionType( bad_client_data_parameter, correct_type_signature.result ) ) @@ -1006,7 +1083,7 @@ def test_init_raises_type_error_with_type_signature_mismatch(self): TypeError, 'the original computation result type' ): _build_test_distribute_aggregate_form_with_computations( - type_signature=computation_types.FunctionType( + type_signature=federated_language.FunctionType( correct_type_signature.parameter, bad_server_state_result ) ) @@ -1015,7 +1092,7 @@ def test_init_raises_type_error_with_type_signature_mismatch(self): TypeError, 'the original computation result type' ): _build_test_distribute_aggregate_form_with_computations( - type_signature=computation_types.FunctionType( + type_signature=federated_language.FunctionType( correct_type_signature.parameter, bad_server_output_result ) ) diff --git a/tensorflow_federated/python/core/backends/mapreduce/intrinsics.py b/tensorflow_federated/python/core/backends/mapreduce/intrinsics.py index 401dde99b2..bb2aac1eb8 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/intrinsics.py +++ b/tensorflow_federated/python/core/backends/mapreduce/intrinsics.py @@ -13,82 +13,72 @@ # limitations under the License. """Intrinsics for the mapreduce backend.""" +import federated_language import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import symbol_binding_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.federated_context import value_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_conversions # Computes the modular sum of client values on the server, securely. Only # supported for integers or nested structures of integers. # # Type signature: <{V}@CLIENTS,M> -> V@SERVER -FEDERATED_SECURE_MODULAR_SUM = intrinsic_defs.IntrinsicDef( +FEDERATED_SECURE_MODULAR_SUM = federated_language.framework.IntrinsicDef( 'FEDERATED_SECURE_MODULAR_SUM', 'federated_secure_modular_sum', - computation_types.FunctionType( + federated_language.FunctionType( parameter=[ - computation_types.FederatedType( - computation_types.AbstractType('V'), placements.CLIENTS + federated_language.FederatedType( + federated_language.AbstractType('V'), federated_language.CLIENTS ), - computation_types.AbstractType('M'), + federated_language.AbstractType('M'), ], - result=computation_types.FederatedType( - computation_types.AbstractType('V'), placements.SERVER + result=federated_language.FederatedType( + federated_language.AbstractType('V'), federated_language.SERVER ), ), - aggregation_kind=intrinsic_defs.AggregationKind.SECURE, + aggregation_kind=federated_language.framework.AggregationKind.SECURE, ) def _cast( - comp: building_blocks.ComputationBuildingBlock, - type_signature: computation_types.TensorType, -) -> building_blocks.Call: + comp: federated_language.framework.ComputationBuildingBlock, + type_signature: federated_language.TensorType, +) -> federated_language.framework.Call: """Casts `comp` to the provided type.""" def cast_fn(value): - def cast_element(element, type_signature: computation_types.TensorType): + def cast_element(element, type_signature: federated_language.TensorType): return tf.cast(element, type_signature.dtype) - if isinstance(comp.type_signature, computation_types.StructType): + if isinstance(comp.type_signature, federated_language.StructType): return structure.map_structure(cast_element, value, type_signature) return cast_element(value, type_signature) cast_proto, cast_type = tensorflow_computation_factory.create_unary_operator( cast_fn, comp.type_signature ) - cast_comp = building_blocks.CompiledComputation( + cast_comp = federated_language.framework.CompiledComputation( cast_proto, type_signature=cast_type ) - return building_blocks.Call(cast_comp, comp) + return federated_language.framework.Call(cast_comp, comp) def create_federated_secure_modular_sum( - value: building_blocks.ComputationBuildingBlock, - modulus: building_blocks.ComputationBuildingBlock, + value: federated_language.framework.ComputationBuildingBlock, + modulus: federated_language.framework.ComputationBuildingBlock, preapply_modulus: bool = True, -) -> building_blocks.ComputationBuildingBlock: +) -> federated_language.framework.ComputationBuildingBlock: r"""Creates a called secure modular sum. Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - modulus: A `building_blocks.ComputationBuildingBlock` to use as the - `modulus` value. + value: A `federated_language.framework.ComputationBuildingBlock` to use as + the value. + modulus: A `federated_language.framework.ComputationBuildingBlock` to use as + the `modulus` value. preapply_modulus: Whether or not to preapply `modulus` to the input `value`. This can be `False` if `value` is guaranteed to already be in range. @@ -98,29 +88,37 @@ def create_federated_secure_modular_sum( Raises: TypeError: If any of the types do not match. """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(modulus, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( + py_typecheck.check_type( + value, federated_language.framework.ComputationBuildingBlock + ) + py_typecheck.check_type( + modulus, federated_language.framework.ComputationBuildingBlock + ) + result_type = federated_language.FederatedType( value.type_signature.member, # pytype: disable=attribute-error - placements.SERVER, + federated_language.SERVER, ) - intrinsic_type = computation_types.FunctionType( + intrinsic_type = federated_language.FunctionType( [ - type_conversions.type_to_non_all_equal(value.type_signature), + federated_language.framework.type_to_non_all_equal( + value.type_signature + ), modulus.type_signature, ], result_type, ) - intrinsic = building_blocks.Intrinsic( + intrinsic = federated_language.framework.Intrinsic( FEDERATED_SECURE_MODULAR_SUM.uri, intrinsic_type ) if not preapply_modulus: - values = building_blocks.Struct([value, modulus]) - return building_blocks.Call(intrinsic, values) + values = federated_language.framework.Struct([value, modulus]) + return federated_language.framework.Call(intrinsic, values) # Pre-insert a modulus to ensure the the input values are within range. - mod_ref = building_blocks.Reference('mod', modulus.type_signature) + mod_ref = federated_language.framework.Reference( + 'mod', modulus.type_signature + ) # In order to run `tf.math.floormod`, our modulus and value must be the same # type. @@ -132,13 +130,13 @@ def create_federated_secure_modular_sum( # at the client as well as at the server for aggregation, we need to broadcast # the modulus to be able to avoid repeating the modulus value (which could # cause accuracy issues if the modulus is non-deterministic). - casted_mod_at_server = building_block_factory.create_federated_value( - casted_mod, placements.SERVER + casted_mod_at_server = federated_language.framework.create_federated_value( + casted_mod, federated_language.SERVER ) - value_with_mod = building_block_factory.create_federated_zip( - building_blocks.Struct([ + value_with_mod = federated_language.framework.create_federated_zip( + federated_language.framework.Struct([ value, - building_block_factory.create_federated_broadcast( + federated_language.framework.create_federated_broadcast( casted_mod_at_server ), ]) @@ -154,32 +152,32 @@ def structural_modulus(value, mod): casted_mod.type_signature, ) ) - structural_modulus_tf = building_blocks.CompiledComputation( + structural_modulus_tf = federated_language.framework.CompiledComputation( structural_modulus_proto, type_signature=structural_modulus_type ) - value_modded = building_block_factory.create_federated_map_or_apply( + value_modded = federated_language.framework.create_federated_map_or_apply( structural_modulus_tf, value_with_mod ) - values = building_blocks.Struct([value_modded, mod_ref]) - return building_blocks.Block( - [('mod', modulus)], building_blocks.Call(intrinsic, values) + values = federated_language.framework.Struct([value_modded, mod_ref]) + return federated_language.framework.Block( + [('mod', modulus)], federated_language.framework.Call(intrinsic, values) ) def create_null_federated_secure_modular_sum(): return create_federated_secure_modular_sum( - building_block_factory.create_federated_value( - building_blocks.Struct([]), placements.CLIENTS + federated_language.framework.create_federated_value( + federated_language.framework.Struct([]), federated_language.CLIENTS ), - building_blocks.Struct([]), + federated_language.framework.Struct([]), preapply_modulus=False, ) def _bind_comp_as_reference(comp): - context = context_stack_impl.context_stack.current - if not isinstance(context, symbol_binding_context.SymbolBindingContext): - raise context_base.ContextError( + context = federated_language.framework.global_context_stack.current + if not isinstance(context, federated_language.framework.SymbolBindingContext): + raise federated_language.framework.ContextError( f'Attempted to construct an intrinsic in context {context} which ' ' does not support binding references.' ) @@ -238,15 +236,17 @@ def federated_secure_modular_sum(value, modulus): TypeError: If the argument is not a federated TFF value placed at `tff.CLIENTS`. """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to be summed' + value = federated_language.to_value(value, type_spec=None) + value = federated_language.framework.ensure_federated_value( + value, federated_language.CLIENTS, 'value to be summed' + ) + federated_language.framework.check_is_structure_of_integers( + value.type_signature ) - type_analysis.check_is_structure_of_integers(value.type_signature) - modulus_value = value_impl.to_value(modulus, type_spec=None) + modulus_value = federated_language.to_value(modulus, type_spec=None) value_member_type = value.type_signature.member # pytype: disable=attribute-error modulus_type = modulus_value.type_signature - if not type_analysis.is_single_integer_or_matches_structure( + if not federated_language.framework.is_single_integer_or_matches_structure( modulus_type, value_member_type ): raise TypeError( @@ -256,13 +256,13 @@ def federated_secure_modular_sum(value, modulus): value_member_type, modulus_type ) ) - if isinstance(modulus_type, computation_types.TensorType) and isinstance( - value_member_type, computation_types.StructType + if isinstance(modulus_type, federated_language.TensorType) and isinstance( + value_member_type, federated_language.StructType ): - modulus_value = value_impl.to_value( + modulus_value = federated_language.to_value( structure.map_structure(lambda _: modulus, value_member_type), type_spec=None, ) comp = create_federated_secure_modular_sum(value.comp, modulus_value.comp) comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) + return federated_language.Value(comp) diff --git a/tensorflow_federated/python/core/backends/mapreduce/intrinsics_test.py b/tensorflow_federated/python/core/backends/mapreduce/intrinsics_test.py index f558dc5c8b..9c8349d479 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/intrinsics_test.py +++ b/tensorflow_federated/python/core/backends/mapreduce/intrinsics_test.py @@ -18,32 +18,26 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.common_libs import golden from tensorflow_federated.python.core.backends.mapreduce import intrinsics -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _create_context() -> ( - federated_computation_context.FederatedComputationContext + federated_language.framework.FederatedComputationContext ): - return federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack + return federated_language.framework.FederatedComputationContext( + federated_language.framework.global_context_stack ) -def _create_fake_value(type_spec: computation_types.Type) -> value_impl.Value: - value = building_blocks.Reference('value', type_spec) - return value_impl.Value(value) +def _create_fake_value( + type_spec: federated_language.Type, +) -> federated_language.Value: + value = federated_language.framework.Reference('value', type_spec) + return federated_language.Value(value) class IntrinsicDefsTest(absltest.TestCase): @@ -61,28 +55,34 @@ class CreateFederatedSecureModularSumTest(absltest.TestCase): def test_raises_type_error_with_none_value(self): modulus = mock.create_autospec( - building_blocks.CompiledComputation, spec_set=True, instance=True + federated_language.framework.CompiledComputation, + spec_set=True, + instance=True, ) with self.assertRaises(TypeError): intrinsics.create_federated_secure_modular_sum(None, modulus) def test_raises_type_error_with_none_modulus(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, + value = federated_language.framework.create_federated_value( + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), + placement=federated_language.CLIENTS, ) with self.assertRaises(TypeError): intrinsics.create_federated_secure_modular_sum(value, None) def test_returns_federated_sum(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, + value = federated_language.framework.create_federated_value( + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), + placement=federated_language.CLIENTS, ) - modulus_type = computation_types.TensorType(np.int32) - modulus = building_blocks.Literal(2, modulus_type) + modulus_type = federated_language.TensorType(np.int32) + modulus = federated_language.framework.Literal(2, modulus_type) comp = intrinsics.create_federated_secure_modular_sum(value, modulus) # Regex replaces compiled computations such as `comp#b03f` to ensure a # consistent output. @@ -103,102 +103,116 @@ class FederatedSecureModularSumTest(parameterized.TestCase): ( 'value_int_clients_and_modulus_int', _create_fake_value( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) ), - _create_fake_value(computation_types.TensorType(np.int32)), + _create_fake_value(federated_language.TensorType(np.int32)), ), ( 'value_struct_int_clients_and_modulus_int', _create_fake_value( - computation_types.FederatedType( - [np.int32, np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32, np.int32], federated_language.CLIENTS ) ), - _create_fake_value(computation_types.TensorType(np.int32)), + _create_fake_value(federated_language.TensorType(np.int32)), ), ( 'value_struct_int_clients_and_modulus_struct', _create_fake_value( - computation_types.FederatedType( - [np.int32, np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32, np.int32], federated_language.CLIENTS ) ), _create_fake_value( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [np.int32, np.int32, np.int32], list ) ), ), ) - @context_stack_test_utils.with_context(_create_context) + @federated_language.framework.with_context(_create_context) def test_returns_result(self, value, modulus): result = intrinsics.federated_secure_modular_sum(value, modulus) - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER + expected_type = federated_language.FederatedType( + value.type_signature.member, federated_language.SERVER ) self.assertEqual(result.type_signature, expected_type) @parameterized.named_parameters( ( 'value_int_unplaced', - _create_fake_value(computation_types.TensorType(np.int32)), - _create_fake_value(computation_types.TensorType(np.int32)), + _create_fake_value(federated_language.TensorType(np.int32)), + _create_fake_value(federated_language.TensorType(np.int32)), ), ( 'value_float_clients', _create_fake_value( - computation_types.FederatedType(np.float32, placements.CLIENTS) + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ) ), - _create_fake_value(computation_types.TensorType(np.int32)), + _create_fake_value(federated_language.TensorType(np.int32)), ), ( 'value_int_server', _create_fake_value( - computation_types.FederatedType(np.int32, placements.SERVER) + federated_language.FederatedType( + np.int32, federated_language.SERVER + ) ), - _create_fake_value(computation_types.TensorType(np.int32)), + _create_fake_value(federated_language.TensorType(np.int32)), ), ( 'modulus_int_clients', _create_fake_value( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) ), _create_fake_value( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) ), ), ( 'modulus_int_server', _create_fake_value( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) ), _create_fake_value( - computation_types.FederatedType(np.int32, placements.SERVER) + federated_language.FederatedType( + np.int32, federated_language.SERVER + ) ), ), ( 'mismatched_structures', _create_fake_value( - computation_types.FederatedType( - [np.int32] * 2, placements.CLIENTS + federated_language.FederatedType( + [np.int32] * 2, federated_language.CLIENTS ), ), - _create_fake_value(computation_types.StructType([np.int32] * 3)), + _create_fake_value(federated_language.StructType([np.int32] * 3)), ), ) - @context_stack_test_utils.with_context(_create_context) + @federated_language.framework.with_context(_create_context) def test_raises_type_error(self, value, modulus): with self.assertRaises(TypeError): intrinsics.federated_secure_modular_sum(value, modulus) def test_raises_context_error_with_no_federated_context(self): value = _create_fake_value( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) - modulus = _create_fake_value(computation_types.TensorType(np.int32)) + modulus = _create_fake_value(federated_language.TensorType(np.int32)) - with self.assertRaises(context_base.ContextError): + with self.assertRaises(federated_language.framework.ContextError): intrinsics.federated_secure_modular_sum(value, modulus) diff --git a/tensorflow_federated/python/core/backends/mapreduce/mapreduce_test_utils.py b/tensorflow_federated/python/core/backends/mapreduce/mapreduce_test_utils.py index d6f18159e6..1b0fd19d12 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/mapreduce_test_utils.py +++ b/tensorflow_federated/python/core/backends/mapreduce/mapreduce_test_utils.py @@ -15,20 +15,13 @@ import collections +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.mapreduce import forms from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process MapReduceFormExample = collections.namedtuple( @@ -38,61 +31,63 @@ def generate_unnamed_type_signature(update, work): """Generates a type signature for the MapReduceForm based on components.""" - parameter = computation_types.StructType([ + parameter = federated_language.StructType([ ( None, - computation_types.FederatedType( - update.type_signature.parameter[0], placements.SERVER + federated_language.FederatedType( + update.type_signature.parameter[0], federated_language.SERVER ), ), ( None, - computation_types.FederatedType( - work.type_signature.parameter[0], placements.CLIENTS + federated_language.FederatedType( + work.type_signature.parameter[0], federated_language.CLIENTS ), ), ]) - result = computation_types.StructType([ + result = federated_language.StructType([ ( None, - computation_types.FederatedType( - update.type_signature.parameter[0], placements.SERVER + federated_language.FederatedType( + update.type_signature.parameter[0], federated_language.SERVER ), ), ( None, - computation_types.FederatedType( - update.type_signature.result[1], placements.SERVER + federated_language.FederatedType( + update.type_signature.result[1], federated_language.SERVER ), ), ]) - return computation_types.FunctionType(parameter, result) + return federated_language.FunctionType(parameter, result) def _make_map_reduce_form_example( - initialize: computation_impl.ConcreteComputation, - type_signature: computation_types.FunctionType, - prepare: computation_impl.ConcreteComputation, - work: computation_impl.ConcreteComputation, - zero: computation_impl.ConcreteComputation, - accumulate: computation_impl.ConcreteComputation, - merge: computation_impl.ConcreteComputation, - report: computation_impl.ConcreteComputation, - secure_sum_bitwidth: computation_impl.ConcreteComputation, - secure_sum_max_input: computation_impl.ConcreteComputation, - secure_sum_modulus: computation_impl.ConcreteComputation, - update: computation_impl.ConcreteComputation, + initialize: federated_language.framework.ConcreteComputation, + type_signature: federated_language.FunctionType, + prepare: federated_language.framework.ConcreteComputation, + work: federated_language.framework.ConcreteComputation, + zero: federated_language.framework.ConcreteComputation, + accumulate: federated_language.framework.ConcreteComputation, + merge: federated_language.framework.ConcreteComputation, + report: federated_language.framework.ConcreteComputation, + secure_sum_bitwidth: federated_language.framework.ConcreteComputation, + secure_sum_max_input: federated_language.framework.ConcreteComputation, + secure_sum_modulus: federated_language.framework.ConcreteComputation, + update: federated_language.framework.ConcreteComputation, ) -> MapReduceFormExample: """Constructs a MapReduceFormExample given the component comps.""" - def _uniquify_reference_names(comp: computation_impl.ConcreteComputation): + def _uniquify_reference_names( + comp: federated_language.framework.ConcreteComputation, + ): building_block = comp.to_building_block() transformed_comp = tree_transformations.uniquify_reference_names( building_block )[0] - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=transformed_comp.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) return MapReduceFormExample( @@ -121,17 +116,20 @@ def get_temperature_sensor_example(): Returns: A tuple of: (1) an instance of `forms.MapReduceForm` and (2) an associated - `computation_base.Computation` that generates an initial state compatible + `federated_language.framework.Computation` that generates an initial state + compatible with the server state expected by the `forms.MapReduceForm`. """ - @federated_computation.federated_computation() + @federated_language.federated_computation() def initialize(): @tensorflow_computation.tf_computation def initialize_tf(): return collections.OrderedDict(num_rounds=0) - return intrinsics.federated_value(initialize_tf(), placements.SERVER) + return federated_language.federated_value( + initialize_tf(), federated_language.SERVER + ) # The state of the server is a singleton tuple containing just the integer # counter `num_rounds`. @@ -148,7 +146,7 @@ def prepare(state): client_state_type = collections.OrderedDict(max_temperature=np.float32) # The client data is a sequence of floats. - client_data_type = computation_types.SequenceType(np.float32) + client_data_type = federated_language.SequenceType(np.float32) @tensorflow_computation.tf_computation(client_data_type, client_state_type) def work(data, state): @@ -243,7 +241,7 @@ def update(state, update): def get_federated_sum_example( *, secure_sum: bool = False -) -> tuple[forms.MapReduceForm, computation_base.Computation]: +) -> tuple[forms.MapReduceForm, federated_language.framework.Computation]: """Constructs `forms.MapReduceForm` which performs a sum aggregation. Args: @@ -258,9 +256,11 @@ def get_federated_sum_example( def initialize_tf(): return () - @federated_computation.federated_computation() + @federated_language.federated_computation() def initialize(): - return intrinsics.federated_value(initialize_tf(), placements.SERVER) + return federated_language.federated_value( + initialize_tf(), federated_language.SERVER + ) server_state_type = initialize_tf.type_signature.result @@ -269,7 +269,7 @@ def prepare(state): return state @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.int32), prepare.type_signature.result + federated_language.SequenceType(np.int32), prepare.type_signature.result ) def work(data, _): client_sum = data.reduce(initial_state=0, reduce_func=tf.add) @@ -351,7 +351,7 @@ def get_mnist_training_example(): server_state_nt = collections.namedtuple('ServerState', 'model num_rounds') # Start with a model filled with zeros, and the round counter set to zero. - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize(): @tensorflow_computation.tf_computation def initialize_tf(): @@ -363,7 +363,9 @@ def initialize_tf(): num_rounds=0, ) - return intrinsics.federated_value(initialize_tf(), placements.SERVER) + return federated_language.federated_value( + initialize_tf(), federated_language.SERVER + ) server_state_tff_type = server_state_nt( model=model_nt(weights=(np.float32, [784, 10]), bias=(np.float32, [10])), @@ -380,7 +382,7 @@ def _prepare(state): batch_nt = collections.namedtuple('Batch', 'x y') batch_tff_type = batch_nt(x=(np.float32, [None, 784]), y=(np.int32, [None])) - dataset_tff_type = computation_types.SequenceType(batch_tff_type) + dataset_tff_type = federated_language.SequenceType(batch_tff_type) model_tff_type = model_nt( weights=(np.float32, [784, 10]), bias=(np.float32, [10]) ) @@ -550,43 +552,49 @@ def get_iterative_process_for_example_with_unused_lambda_arg(): server_state_type = collections.OrderedDict(num_clients=np.int32) def _bind_federated_value(unused_input, input_type, federated_output_value): - federated_input_type = computation_types.FederatedType( - input_type, placements.CLIENTS + federated_input_type = federated_language.FederatedType( + input_type, federated_language.CLIENTS ) - wrapper = federated_computation.federated_computation( + wrapper = federated_language.federated_computation( lambda _: federated_output_value, federated_input_type ) return wrapper(unused_input) def count_clients_federated(client_data): - client_ones = intrinsics.federated_value(1, placements.CLIENTS) + client_ones = federated_language.federated_value( + 1, federated_language.CLIENTS + ) client_ones = _bind_federated_value( - client_data, computation_types.SequenceType(np.str_), client_ones + client_data, federated_language.SequenceType(np.str_), client_ones ) - return intrinsics.federated_sum(client_ones) + return federated_language.federated_sum(client_ones) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value( - collections.OrderedDict(num_clients=0), placements.SERVER + return federated_language.federated_value( + collections.OrderedDict(num_clients=0), federated_language.SERVER ) - @federated_computation.federated_computation([ - computation_types.FederatedType(server_state_type, placements.SERVER), - computation_types.FederatedType( - computation_types.SequenceType(np.str_), placements.CLIENTS + @federated_language.federated_computation([ + federated_language.FederatedType( + server_state_type, federated_language.SERVER + ), + federated_language.FederatedType( + federated_language.SequenceType(np.str_), federated_language.CLIENTS ), ]) def next_fn(server_state, client_val): """`next` function for `tff.templates.IterativeProcess`.""" - server_update = intrinsics.federated_zip( + server_update = federated_language.federated_zip( collections.OrderedDict(num_clients=count_clients_federated(client_val)) ) - server_output = intrinsics.federated_value((), placements.SERVER) + server_output = federated_language.federated_value( + (), federated_language.SERVER + ) server_output = _bind_federated_value( - intrinsics.federated_broadcast(server_state), + federated_language.federated_broadcast(server_state), server_state_type, server_output, ) @@ -607,7 +615,7 @@ def _bind_tf_function(unused_input, tf_func): tf_wrapper, input_federated_type.member, # pytype: disable=attribute-error ) - return intrinsics.federated_map(wrapper, unused_input) + return federated_language.federated_map(wrapper, unused_input) def count_clients_federated(client_data): @tf.function @@ -615,29 +623,31 @@ def client_ones_fn(): return np.ones(shape=[], dtype=np.int32) client_ones = _bind_tf_function(client_data, client_ones_fn) - return intrinsics.federated_sum(client_ones) + return federated_language.federated_sum(client_ones) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value( - collections.OrderedDict(num_clients=0), placements.SERVER + return federated_language.federated_value( + collections.OrderedDict(num_clients=0), federated_language.SERVER ) - @federated_computation.federated_computation([ - computation_types.FederatedType(server_state_type, placements.SERVER), - computation_types.FederatedType( - computation_types.SequenceType(np.str_), placements.CLIENTS + @federated_language.federated_computation([ + federated_language.FederatedType( + server_state_type, federated_language.SERVER + ), + federated_language.FederatedType( + federated_language.SequenceType(np.str_), federated_language.CLIENTS ), ]) def next_fn(server_state, client_val): """`next` function for `tff.templates.IterativeProcess`.""" - server_update = intrinsics.federated_zip( + server_update = federated_language.federated_zip( collections.OrderedDict(num_clients=count_clients_federated(client_val)) ) - server_output = intrinsics.federated_sum( + server_output = federated_language.federated_sum( _bind_tf_function( - intrinsics.federated_broadcast(server_state), tf.timestamp + federated_language.federated_broadcast(server_state), tf.timestamp ) ) @@ -649,22 +659,22 @@ def next_fn(server_state, client_val): def get_iterative_process_for_example_with_lambda_returning_aggregation(): """Gets iterative process with indirection to the called intrinsic.""" server_state_type = collections.OrderedDict(num_clients=np.int32) - client_val_type = computation_types.FederatedType( - server_state_type, placements.CLIENTS + client_val_type = federated_language.FederatedType( + server_state_type, federated_language.CLIENTS ) - @federated_computation.federated_computation + @federated_language.federated_computation def computation_returning_lambda(): - @federated_computation.federated_computation(np.int32) + @federated_language.federated_computation(np.int32) def computation_returning_sum(x): tuple_containing_intrinsic = [ - building_blocks.Intrinsic( + federated_language.framework.Intrinsic( 'federated_sum', - computation_types.FunctionType( + federated_language.FunctionType( client_val_type, - computation_types.FederatedType( - client_val_type.member, placements.SERVER + federated_language.FederatedType( + client_val_type.member, federated_language.SERVER ), ), ), @@ -675,20 +685,22 @@ def computation_returning_sum(x): return computation_returning_sum - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value( - collections.OrderedDict(num_clients=0), placements.SERVER + return federated_language.federated_value( + collections.OrderedDict(num_clients=0), federated_language.SERVER ) - @federated_computation.federated_computation([ - computation_types.FederatedType(server_state_type, placements.SERVER), + @federated_language.federated_computation([ + federated_language.FederatedType( + server_state_type, federated_language.SERVER + ), client_val_type, ]) def next_fn(server_state, client_val): """`next` function for `tff.templates.IterativeProcess`.""" - server_update = intrinsics.federated_sum(client_val) - state_at_clients = intrinsics.federated_broadcast(server_state) + server_update = federated_language.federated_sum(client_val) + state_at_clients = federated_language.federated_broadcast(server_state) lambda_returning_sum = computation_returning_lambda() sum_fn = lambda_returning_sum(1) server_output = sum_fn(state_at_clients) diff --git a/tensorflow_federated/python/core/backends/native/BUILD b/tensorflow_federated/python/core/backends/native/BUILD index 9b5e6bcfca..074be2406c 100644 --- a/tensorflow_federated/python/core/backends/native/BUILD +++ b/tensorflow_federated/python/core/backends/native/BUILD @@ -34,14 +34,11 @@ py_library( name = "compiler", srcs = ["compiler.py"], deps = [ - "//tensorflow_federated/python/common_libs:tracing", "//tensorflow_federated/python/core/backends/mapreduce:compiler", "//tensorflow_federated/python/core/environments/tensorflow_backend:compiled_computation_transformations", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", "//tensorflow_federated/python/core/impl/compiler:transformations", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", + "@federated_language//federated_language", ], ) @@ -50,12 +47,7 @@ py_test( srcs = ["compiler_test.py"], deps = [ ":compiler", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -69,11 +61,9 @@ py_library( ":compiler", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_executor_bindings", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/context_stack:set_default_context", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:cpp_executor_factory", "//tensorflow_federated/python/core/impl/executors:executor_bindings", + "@federated_language//federated_language", ], ) @@ -87,14 +77,8 @@ py_test( ":cpp_execution_contexts", ":execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/context_stack:get_context_stack", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executors:executor_bindings", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -105,12 +89,9 @@ py_library( ":compiler", ":mergeable_comp_compiler", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", "//tensorflow_federated/python/core/impl/execution_contexts:mergeable_comp_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:executor_factory", + "@federated_language//federated_language", ], ) @@ -122,8 +103,8 @@ py_test( "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_impl", "//tensorflow_federated/python/core/impl/executors:remote_executor_grpc_stub", + "@federated_language//federated_language", ], ) @@ -135,18 +116,10 @@ py_library( "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_building_block_factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_computation_factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", "//tensorflow_federated/python/core/impl/compiler:transformations", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", "//tensorflow_federated/python/core/impl/compiler:tree_transformations", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", "//tensorflow_federated/python/core/impl/execution_contexts:mergeable_comp_execution_context", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -156,13 +129,8 @@ py_test( deps = [ ":mergeable_comp_compiler", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", "//tensorflow_federated/python/core/impl/execution_contexts:mergeable_comp_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:executor_factory", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/backends/native/compiler.py b/tensorflow_federated/python/core/backends/native/compiler.py index 64436afee4..2a0732a541 100644 --- a/tensorflow_federated/python/core/backends/native/compiler.py +++ b/tensorflow_federated/python/core/backends/native/compiler.py @@ -16,23 +16,20 @@ from typing import Optional from absl import logging +import federated_language import tensorflow as tf -from tensorflow_federated.python.common_libs import tracing from tensorflow_federated.python.core.backends.mapreduce import compiler from tensorflow_federated.python.core.environments.tensorflow_backend import compiled_computation_transformations from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations -from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import transformations -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl def transform_to_native_form( - comp: computation_impl.ConcreteComputation, + comp: federated_language.framework.ConcreteComputation, transform_math_to_tf: bool = False, grappler_config: Optional[tf.compat.v1.ConfigProto] = None, -) -> computation_impl.ConcreteComputation: +) -> federated_language.framework.ConcreteComputation: """Compiles a computation for execution in the TFF native runtime. This function transforms the proto underlying `comp` by transforming it @@ -40,7 +37,8 @@ def transform_to_native_form( definition). Args: - comp: Instance of `computation_impl.ConcreteComputation` to compile. + comp: Instance of `federated_language.framework.ConcreteComputation` to + compile. transform_math_to_tf: Whether to additional transform math to TensorFlow graphs. Necessary if running on a execution state without ReferenceResolvingExecutors underneath FederatingExecutors. @@ -49,16 +47,17 @@ def transform_to_native_form( optimizations wil be applied. Returns: - A new `computation_impl.ConcreteComputation` representing the compiled + A new `federated_language.framework.ConcreteComputation` representing the + compiled version of `comp`. """ - proto = computation_impl.ConcreteComputation.get_proto(comp) + proto = federated_language.framework.ConcreteComputation.get_proto(comp) computation_building_block = ( - building_blocks.ComputationBuildingBlock.from_proto(proto) + federated_language.framework.ComputationBuildingBlock.from_proto(proto) ) try: logging.debug('Compiling TFF computation to CDF.') - with tracing.span( + with federated_language.framework.span( 'transform_to_native_form', 'to_call_dominant', span=True ): call_dominant_form = transformations.to_call_dominant( @@ -68,7 +67,7 @@ def transform_to_native_form( logging.debug(call_dominant_form.formatted_representation()) if transform_math_to_tf: logging.debug('Compiling local computations to TensorFlow.') - with tracing.span( + with federated_language.framework.span( 'transform_to_native_form', 'compile_local_subcomputations_to_tensorflow', span=True, @@ -81,7 +80,7 @@ def transform_to_native_form( logging.debug('Computation compiled to:') logging.debug(call_dominant_form.formatted_representation()) if grappler_config is not None: - with tracing.span( + with federated_language.framework.span( 'transform_to_native_form', 'optimize_tf_graphs', span=True ): call_dominant_form, _ = ( @@ -89,7 +88,7 @@ def transform_to_native_form( call_dominant_form, grappler_config ) ) - with tracing.span( + with federated_language.framework.span( 'transform_to_native_form', 'transform_tf_call_ops_disable_grappler', span=True, @@ -99,7 +98,7 @@ def transform_to_native_form( call_dominant_form ) ) - with tracing.span( + with federated_language.framework.span( 'transform_to_native_form', 'transform_tf_add_ids', span=True ): form_with_ids, _ = ( @@ -107,9 +106,9 @@ def transform_to_native_form( disabled_grapler_form ) ) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=form_with_ids.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) except ValueError as e: logging.debug('Compilation for native runtime failed with error %s', e) @@ -142,9 +141,9 @@ def desugar_and_transform_to_native(comp): # adds TF cache IDs to them. It is crucial that these transformations execute # in this order. native_form = transform_to_native_form( - computation_impl.ConcreteComputation( + federated_language.framework.ConcreteComputation( computation_proto=intrinsics_desugared_bb.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ), grappler_config=grappler_config, ) diff --git a/tensorflow_federated/python/core/backends/native/compiler_test.py b/tensorflow_federated/python/core/backends/native/compiler_test.py index 6d081d1411..9eaeda8d63 100644 --- a/tensorflow_federated/python/core/backends/native/compiler_test.py +++ b/tensorflow_federated/python/core/backends/native/compiler_test.py @@ -13,33 +13,28 @@ # limitations under the License. from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.backends.native import compiler -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements class DesugarAndTransformTest(absltest.TestCase): def test_desugaring_sum_insert_id_for_tf_computations(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def fed_sum(x): - return intrinsics.federated_sum(x) + return federated_language.federated_sum(x) reduced_comp = compiler.desugar_and_transform_to_native(fed_sum) reduced_bb = reduced_comp.to_building_block() def _check_tf_computations_have_ids(comp): if ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'tensorflow' and not comp.proto.tensorflow.cache_key.id ): @@ -50,7 +45,7 @@ def _check_tf_computations_have_ids(comp): return comp, False # Doesn't raise. - transformation_utils.transform_postorder( + federated_language.framework.transform_postorder( reduced_bb, _check_tf_computations_have_ids ) diff --git a/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py index d61dc3c46e..d2ab1f0282 100644 --- a/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py +++ b/tensorflow_federated/python/core/backends/native/cpp_execution_contexts.py @@ -15,12 +15,11 @@ from collections.abc import Sequence +import federated_language + from tensorflow_federated.python.core.backends.native import compiler from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_executor_bindings from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.context_stack import set_default_context -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executor_stacks import cpp_executor_factory from tensorflow_federated.python.core.impl.executors import executor_bindings @@ -44,7 +43,7 @@ def create_sync_local_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> sync_execution_context.SyncExecutionContext: +) -> federated_language.framework.SyncExecutionContext: """Creates a local execution context backed by TFF-C++ runtime. Args: @@ -65,7 +64,7 @@ def create_sync_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) - context = sync_execution_context.SyncExecutionContext( + context = federated_language.framework.SyncExecutionContext( executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, @@ -95,14 +94,14 @@ def set_sync_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - set_default_context.set_default_context(context) + federated_language.framework.set_default_context(context) def create_async_local_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> async_execution_context.AsyncExecutionContext: +) -> federated_language.framework.AsyncExecutionContext: """Creates a local async execution context backed by TFF-C++ runtime. Args: @@ -115,7 +114,8 @@ def create_async_local_cpp_execution_context( stream_structs: The flag to enable decomposing and streaming struct values. Returns: - An instance of `context_base.AsyncContext` representing the TFF-C++ runtime. + An instance of `federated_language.framework.AsyncContext` representing the + TFF-C++ runtime. """ del stream_structs # Unused. factory = cpp_executor_factory.local_cpp_executor_factory( @@ -123,7 +123,7 @@ def create_async_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) - context = async_execution_context.AsyncExecutionContext( + context = federated_language.framework.AsyncExecutionContext( executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, @@ -153,18 +153,18 @@ def set_async_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - set_default_context.set_default_context(context) + federated_language.framework.set_default_context(context) def create_sync_remote_cpp_execution_context( channels: Sequence[executor_bindings.GRPCChannel], default_num_clients: int = 0, -) -> sync_execution_context.SyncExecutionContext: +) -> federated_language.framework.SyncExecutionContext: """Creates a remote execution context backed by TFF-C++ runtime.""" factory = cpp_executor_factory.remote_cpp_executor_factory( channels=channels, default_num_clients=default_num_clients ) - context = sync_execution_context.SyncExecutionContext( + context = federated_language.framework.SyncExecutionContext( executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, @@ -180,18 +180,18 @@ def set_sync_remote_cpp_execution_context( context = create_sync_remote_cpp_execution_context( channels=channels, default_num_clients=default_num_clients ) - set_default_context.set_default_context(context) + federated_language.framework.set_default_context(context) def create_async_remote_cpp_execution_context( channels: Sequence[executor_bindings.GRPCChannel], default_num_clients: int = 0, -) -> async_execution_context.AsyncExecutionContext: +) -> federated_language.framework.AsyncExecutionContext: """Creates a remote execution context backed by TFF-C++ runtime.""" factory = cpp_executor_factory.remote_cpp_executor_factory( channels=channels, default_num_clients=default_num_clients ) - context = async_execution_context.AsyncExecutionContext( + context = federated_language.framework.AsyncExecutionContext( executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, diff --git a/tensorflow_federated/python/core/backends/native/cpp_execution_contexts_test.py b/tensorflow_federated/python/core/backends/native/cpp_execution_contexts_test.py index cd295ab66d..58e58a7e66 100644 --- a/tensorflow_federated/python/core/backends/native/cpp_execution_contexts_test.py +++ b/tensorflow_federated/python/core/backends/native/cpp_execution_contexts_test.py @@ -18,20 +18,14 @@ import time from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import cpp_execution_contexts from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.context_stack import get_context_stack -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executors import executor_bindings -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _assert_signature_equal(first_obj, second_obj): @@ -61,7 +55,7 @@ def test_has_same_signature(self): def test_returns_async_context(self): context = cpp_execution_contexts.create_async_local_cpp_execution_context() self.assertIsInstance( - context, async_execution_context.AsyncExecutionContext + context, federated_language.framework.AsyncExecutionContext ) def test_install_and_execute_in_context(self): @@ -71,7 +65,7 @@ def test_install_and_execute_in_context(self): def add_one(x): return x + 1 - with get_context_stack.get_context_stack().install(context): + with federated_language.framework.get_context_stack().install(context): val_coro = add_one(1) self.assertTrue(asyncio.iscoroutine(val_coro)) self.assertEqual(asyncio.run(val_coro), 2) @@ -79,13 +73,13 @@ def add_one(x): def test_install_and_execute_computations_with_different_cardinalities(self): context = cpp_execution_contexts.create_async_local_cpp_execution_context() - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def repackage_arg(x): return [x, x] - with get_context_stack.get_context_stack().install(context): + with federated_language.framework.get_context_stack().install(context): single_val_coro = repackage_arg([1]) second_val_coro = repackage_arg([1, 2]) self.assertTrue(asyncio.iscoroutine(single_val_coro)) @@ -115,7 +109,9 @@ def test_has_same_signature(self): def test_returns_sync_context(self): context = cpp_execution_contexts.create_sync_local_cpp_execution_context() - self.assertIsInstance(context, sync_execution_context.SyncExecutionContext) + self.assertIsInstance( + context, federated_language.framework.SyncExecutionContext + ) class SetSyncLocalCPPExecutionContextTest(absltest.TestCase): @@ -137,17 +133,20 @@ def test_returns_sync_context(self): context = cpp_execution_contexts.create_sync_remote_cpp_execution_context( channels=channels ) - self.assertIsInstance(context, sync_execution_context.SyncExecutionContext) + self.assertIsInstance( + context, federated_language.framework.SyncExecutionContext + ) def test_returns_same_python_structure(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( collections.OrderedDict(a=np.int32, b=np.float32) ) def identity(x): return x context = cpp_execution_contexts.create_sync_local_cpp_execution_context() - with get_context_stack.get_context_stack().install(context): + with federated_language.framework.get_context_stack().install(context): odict = identity(collections.OrderedDict(a=0, b=1.0)) self.assertIsInstance(odict, collections.OrderedDict) @@ -160,7 +159,7 @@ def multiply(ordered_dict): return ordered_dict['x'] * ordered_dict['y'] context = cpp_execution_contexts.create_sync_local_cpp_execution_context() - with get_context_stack.get_context_stack().install(context): + with federated_language.framework.get_context_stack().install(context): zero = multiply(collections.OrderedDict(x=0, y=1)) one = multiply(collections.OrderedDict(x=1, y=1)) @@ -217,16 +216,18 @@ def create_sequence(): return tf.data.Dataset.range(5) context = cpp_execution_contexts.create_sync_local_cpp_execution_context() - with get_context_stack.get_context_stack().install(context): + with federated_language.framework.get_context_stack().install(context): with self.subTest('unplaced'): sequence = create_sequence() self.assertEqual(sequence, [0, 1, 2, 3, 4]) with self.subTest('federated'): - @federated_computation.federated_computation + @federated_language.federated_computation def create_federated_sequence(): - return intrinsics.federated_eval(create_sequence, placements.SERVER) + return federated_language.federated_eval( + create_sequence, federated_language.SERVER + ) sequence = create_federated_sequence() self.assertEqual(sequence, [0, 1, 2, 3, 4]) diff --git a/tensorflow_federated/python/core/backends/native/execution_contexts.py b/tensorflow_federated/python/core/backends/native/execution_contexts.py index f09be387e2..21b071c519 100644 --- a/tensorflow_federated/python/core/backends/native/execution_contexts.py +++ b/tensorflow_federated/python/core/backends/native/execution_contexts.py @@ -16,19 +16,17 @@ from collections.abc import Sequence from typing import Optional +import federated_language + from tensorflow_federated.python.core.backends.native import compiler from tensorflow_federated.python.core.backends.native import mergeable_comp_compiler from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context from tensorflow_federated.python.core.impl.execution_contexts import mergeable_comp_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executor_stacks import executor_factory def create_mergeable_comp_execution_context( - async_contexts: Sequence[context_base.AsyncContext], + async_contexts: Sequence[federated_language.framework.AsyncContext], num_subrounds: Optional[int] = None, ) -> mergeable_comp_execution_context.MergeableCompExecutionContext: """Creates context which compiles to and executes mergeable comp form. @@ -55,7 +53,7 @@ def create_mergeable_comp_execution_context( def set_mergeable_comp_execution_context( - async_contexts: Sequence[context_base.AsyncContext], + async_contexts: Sequence[federated_language.framework.AsyncContext], num_subrounds: Optional[int] = None, ): """Sets context which compiles to and executes mergeable comp form. @@ -73,14 +71,14 @@ def set_mergeable_comp_execution_context( async_contexts=async_contexts, num_subrounds=num_subrounds, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) def create_async_local_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> async_execution_context.AsyncExecutionContext: +) -> federated_language.framework.AsyncExecutionContext: """Returns an execution context backed by C++ runtime. This execution context starts a C++ worker assumed to be at path @@ -103,7 +101,7 @@ def create_async_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - return async_execution_context.AsyncExecutionContext( + return federated_language.framework.AsyncExecutionContext( executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, @@ -122,14 +120,14 @@ def set_async_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) def create_sync_local_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> sync_execution_context.SyncExecutionContext: +) -> federated_language.framework.SyncExecutionContext: """Returns an execution context backed by C++ runtime. This execution context starts a C++ worker assumed to be at path @@ -153,7 +151,7 @@ def create_sync_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - return sync_execution_context.SyncExecutionContext( + return federated_language.framework.SyncExecutionContext( executor_fn=factory, compiler_fn=compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, @@ -172,4 +170,4 @@ def set_sync_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) diff --git a/tensorflow_federated/python/core/backends/native/execution_contexts_test.py b/tensorflow_federated/python/core/backends/native/execution_contexts_test.py index 27c91653d4..58fbb4c183 100644 --- a/tensorflow_federated/python/core/backends/native/execution_contexts_test.py +++ b/tensorflow_federated/python/core/backends/native/execution_contexts_test.py @@ -20,13 +20,13 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import tensorflow as tf from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_impl from tensorflow_federated.python.core.impl.executors import remote_executor_grpc_stub @@ -93,7 +93,7 @@ def foo(): def _create_mock_remote_executor_grpc_stub( - computation: computation_impl.ConcreteComputation, + computation: federated_language.framework.ConcreteComputation, ) -> remote_executor_grpc_stub.RemoteExecutorGrpcStub: class _GetExecutorResponse: diff --git a/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler.py b/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler.py index ce47584888..f40ab2b887 100644 --- a/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler.py +++ b/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler.py @@ -13,22 +13,14 @@ # limitations under the License. """A MergeableCompForm compiler for the native backend.""" +import federated_language from tensorflow_federated.python.core.backends.mapreduce import compiler from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import transformations -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl from tensorflow_federated.python.core.impl.execution_contexts import mergeable_comp_execution_context -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _compile_to_tf(fn): @@ -38,48 +30,51 @@ def _compile_to_tf(fn): def _select_output_result_and_wrap_as_noarg_tensorflow( - fn: building_blocks.Lambda, path: building_block_factory.Path -) -> computation_impl.ConcreteComputation: - selected_and_wrapped = building_blocks.Lambda( + fn: federated_language.framework.Lambda, + path: federated_language.framework.Path, +) -> federated_language.framework.ConcreteComputation: + selected_and_wrapped = federated_language.framework.Lambda( None, None, - building_block_factory.select_output_from_lambda(fn, path).result, + federated_language.framework.select_output_from_lambda(fn, path).result, ) selected_and_compiled = _compile_to_tf(selected_and_wrapped) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=selected_and_compiled.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) def _select_output_result_and_wrap_as_tensorflow( - fn: building_blocks.Lambda, path: building_block_factory.Path -) -> computation_impl.ConcreteComputation: - selected_fn = building_block_factory.select_output_from_lambda( + fn: federated_language.framework.Lambda, + path: federated_language.framework.Path, +) -> federated_language.framework.ConcreteComputation: + selected_fn = federated_language.framework.select_output_from_lambda( fn, path ).result selected_and_compiled = _compile_to_tf(selected_fn) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=selected_and_compiled.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) def _extract_federated_aggregate_computations( - before_agg: building_blocks.Lambda, + before_agg: federated_language.framework.Lambda, ): """Extracts aggregate computations from `before_agg`. Args: - before_agg: a `building_blocks.ComputationBuildingBlock` representing the - before-aggregate portion of a computation split on `federated_aggregate`. + before_agg: a `federated_language.framework.ComputationBuildingBlock` + representing the before-aggregate portion of a computation split on + `federated_aggregate`. Returns: A tuple of four ConcreteComputations corresponding to the aggregate functions in `before_agg`. """ federated_aggregate_arguments = ( - building_block_factory.select_output_from_lambda( + federated_language.framework.select_output_from_lambda( before_agg, 'federated_aggregate_param' ) ) @@ -104,22 +99,22 @@ def _extract_federated_aggregate_computations( def _ensure_lambda( - building_block: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Lambda: + building_block: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.Lambda: """Wraps a functional building block as a lambda if necessary.""" if not isinstance( - building_block.type_signature, computation_types.FunctionType + building_block.type_signature, federated_language.FunctionType ): raise ValueError( f'Expected a `tff.FunctionType`, found {building_block.type_signature}.' ) - if not isinstance(building_block, building_blocks.Lambda): + if not isinstance(building_block, federated_language.framework.Lambda): if building_block.type_signature.parameter is not None: # pytype: disable=attribute-error - name_generator = building_block_factory.unique_name_generator( + name_generator = federated_language.framework.unique_name_generator( building_block ) parameter_name = next(name_generator) - argument = building_blocks.Reference( + argument = federated_language.framework.Reference( parameter_name, building_block.type_signature.parameter, # pytype: disable=attribute-error ) @@ -128,15 +123,15 @@ def _ensure_lambda( argument = None parameter_type = None parameter_name = None - result = building_blocks.Call(building_block, argument) - building_block = building_blocks.Lambda( + result = federated_language.framework.Call(building_block, argument) + building_block = federated_language.framework.Lambda( parameter_name, parameter_type, result ) return building_block def compile_to_mergeable_comp_form( - comp: computation_impl.ConcreteComputation, + comp: federated_language.framework.ConcreteComputation, ) -> mergeable_comp_execution_context.MergeableCompForm: """Compiles a computation with a single aggregation to `MergeableCompForm`. @@ -145,11 +140,11 @@ def compile_to_mergeable_comp_form( instance of `mergeable_comp_execution_context.MergeableCompForm`. Args: - comp: Instance of `computation_impl.ConcreteComputation` to compile. Assumed - to be representable as a computation with a single aggregation in its - body, so that for example two parallel aggregations are allowed, but - multiple dependent aggregations are disallowed. Additionally assumed to be - of functional type. + comp: Instance of `federated_language.framework.ConcreteComputation` to + compile. Assumed to be representable as a computation with a single + aggregation in its body, so that for example two parallel aggregations are + allowed, but multiple dependent aggregations are disallowed. Additionally + assumed to be of functional type. Returns: A semantically equivalent instance of @@ -171,7 +166,7 @@ def compile_to_mergeable_comp_form( # We transform the body of this computation to easily preserve the top-level # lambda required by force-aligning. call_dominant_body_bb = transformations.to_call_dominant(lowered_bb.result) - call_dominant_bb = building_blocks.Lambda( + call_dominant_bb = federated_language.framework.Lambda( lowered_bb.parameter_name, lowered_bb.parameter_type, call_dominant_body_bb, @@ -179,7 +174,9 @@ def compile_to_mergeable_comp_form( # This check should not throw false positives because we just ensured we are # in call-dominant form. - tree_analysis.check_aggregate_not_dependent_on_aggregate(call_dominant_bb) + federated_language.framework.check_aggregate_not_dependent_on_aggregate( + call_dominant_bb + ) before_agg, after_agg = transformations.force_align_and_split_by_intrinsics( call_dominant_bb, @@ -193,7 +190,7 @@ def compile_to_mergeable_comp_form( report_proto, report_type = tensorflow_computation_factory.create_identity( merge_fn_type.result ) - identity_report = building_blocks.CompiledComputation( + identity_report = federated_language.framework.CompiledComputation( report_proto, type_signature=report_type ) @@ -201,20 +198,20 @@ def compile_to_mergeable_comp_form( _extract_federated_aggregate_computations(before_agg) ) - before_agg_callable = computation_impl.ConcreteComputation( + before_agg_callable = federated_language.framework.ConcreteComputation( computation_proto=before_agg.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) - after_agg_callable = computation_impl.ConcreteComputation( + after_agg_callable = federated_language.framework.ConcreteComputation( computation_proto=after_agg.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) if before_agg.type_signature.parameter is not None: # TODO: b/147499373 - If None-arguments were uniformly represented as empty # tuples, we would be able to avoid this (and related) ugly casing. - @federated_computation.federated_computation( + @federated_language.federated_computation( before_agg.type_signature.parameter ) def up_to_merge_computation(arg): @@ -223,49 +220,55 @@ def up_to_merge_computation(arg): ] value_to_aggregate = federated_aggregate_args[0] zero = zero_comp() - return intrinsics.federated_aggregate( + return federated_language.federated_aggregate( value_to_aggregate, zero, accumulate_comp, merge_comp, identity_report ) - @federated_computation.federated_computation( + @federated_language.federated_computation( before_agg.type_signature.parameter, - computation_types.FederatedType( + federated_language.FederatedType( identity_report.type_signature.result, # pytype: disable=attribute-error - placements.SERVER, + federated_language.SERVER, ), ) def after_merge_computation(top_level_arg, merge_result): - reported_result = intrinsics.federated_map(report_comp, merge_result) + reported_result = federated_language.federated_map( + report_comp, merge_result + ) return after_agg_callable(top_level_arg, [reported_result]) else: - @federated_computation.federated_computation() + @federated_language.federated_computation() def up_to_merge_computation(): federated_aggregate_args = before_agg_callable()[ 'federated_aggregate_param' ] value_to_aggregate = federated_aggregate_args[0] zero = zero_comp() - return intrinsics.federated_aggregate( + return federated_language.federated_aggregate( value_to_aggregate, zero, accumulate_comp, merge_comp, identity_report ) - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( identity_report.type_signature.result, # pytype: disable=attribute-error - placements.SERVER, + federated_language.SERVER, ) ) def after_merge_computation(merge_result): - reported_result = intrinsics.federated_map(report_comp, merge_result) + reported_result = federated_language.federated_map( + report_comp, merge_result + ) return after_agg_callable([[reported_result]]) - annotated_type_signature = computation_types.FunctionType( + annotated_type_signature = federated_language.FunctionType( after_merge_computation.type_signature.parameter, original_return_type ) - after_merge_computation = computation_impl.ConcreteComputation.with_type( - after_merge_computation, annotated_type_signature + after_merge_computation = ( + federated_language.framework.ConcreteComputation.with_type( + after_merge_computation, annotated_type_signature + ) ) return mergeable_comp_execution_context.MergeableCompForm( diff --git a/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py b/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py index 9a1e744f07..a87bba9119 100644 --- a/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py +++ b/tensorflow_federated/python/core/backends/native/mergeable_comp_compiler_test.py @@ -13,23 +13,18 @@ # limitations under the License. from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.backends.native import mergeable_comp_compiler from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context from tensorflow_federated.python.core.impl.execution_contexts import mergeable_comp_execution_context from tensorflow_federated.python.core.impl.executor_stacks import executor_factory # pylint: enable=line-too-long -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils def _create_test_context(): factory = executor_factory.local_cpp_executor_factory() - context = async_execution_context.AsyncExecutionContext( + context = federated_language.framework.AsyncExecutionContext( executor_fn=factory, transform_args=tensorflow_computation.transform_args, transform_result=tensorflow_computation.transform_result, @@ -48,12 +43,10 @@ def build_whimsy_computation_with_aggregation_and_after( def compute_sum(x, y): return x + y - @federated_computation.federated_computation( - server_arg_type, clients_arg_type - ) + @federated_language.federated_computation(server_arg_type, clients_arg_type) def aggregation_comp(server_arg, client_arg): - summed_client_value = intrinsics.federated_sum(client_arg) - return intrinsics.federated_map( + summed_client_value = federated_language.federated_sum(client_arg) + return federated_language.federated_map( compute_sum, (server_arg, summed_client_value) ) @@ -73,13 +66,13 @@ def compute_tuple_sum(x): def compute_sum(x, y): return x + y - @federated_computation.federated_computation( - server_arg_type, clients_arg_type - ) + @federated_language.federated_computation(server_arg_type, clients_arg_type) def aggregation_comp(server_arg, client_arg): - client_sums = intrinsics.federated_map(compute_tuple_sum, client_arg) - summed_client_value = intrinsics.federated_sum(client_sums) - return intrinsics.federated_map( + client_sums = federated_language.federated_map( + compute_tuple_sum, client_arg + ) + summed_client_value = federated_language.federated_sum(client_sums) + return federated_language.federated_map( compute_sum, (server_arg, summed_client_value) ) @@ -99,22 +92,22 @@ def compute_tuple_sum(x): def compute_sum(x, y): return x + y - @federated_computation.federated_computation + @federated_language.federated_computation def package_args_as_tuple(x, y): return [x, y] - @federated_computation.federated_computation( - server_arg_type, clients_arg_type - ) + @federated_language.federated_computation(server_arg_type, clients_arg_type) def aggregation_comp(server_arg, client_arg): - client_sums = intrinsics.federated_map(compute_tuple_sum, client_arg) - summed_client_value = intrinsics.federated_sum(client_sums) - broadcast_sum = intrinsics.federated_broadcast(summed_client_value) + client_sums = federated_language.federated_map( + compute_tuple_sum, client_arg + ) + summed_client_value = federated_language.federated_sum(client_sums) + broadcast_sum = federated_language.federated_broadcast(summed_client_value) # Adding a function call here requires normalization into CDF before # checking the aggregation-dependence condition. client_tuple = package_args_as_tuple(client_sums, broadcast_sum) - summed_client_value = intrinsics.federated_sum(client_tuple[0]) - return intrinsics.federated_map( + summed_client_value = federated_language.federated_sum(client_tuple[0]) + return federated_language.federated_map( compute_sum, (server_arg, summed_client_value) ) @@ -126,16 +119,18 @@ def tf_multiply_int(x, y): return x * y -@federated_computation.federated_computation(np.int32, np.int32) +@federated_language.federated_computation(np.int32, np.int32) def return_list(x, y): return [x, y] -@federated_computation.federated_computation( - computation_types.FederatedType([np.int32, np.int32], placements.SERVER) +@federated_language.federated_computation( + federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER + ) ) def server_placed_mult(arg): - return intrinsics.federated_map(tf_multiply_int, arg) + return federated_language.federated_map(tf_multiply_int, arg) class MergeableCompCompilerTest(absltest.TestCase): @@ -153,14 +148,14 @@ def _invoke_mergeable_form_on_arg( def test_raises_two_dependent_aggregates(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def dependent_agg_comp(server_arg): - arg_at_clients = intrinsics.federated_broadcast(server_arg) - sum_result = intrinsics.federated_sum(arg_at_clients) - rebroadcast_sum = intrinsics.federated_broadcast(sum_result) - return intrinsics.federated_sum(rebroadcast_sum) + arg_at_clients = federated_language.federated_broadcast(server_arg) + sum_result = federated_language.federated_sum(arg_at_clients) + rebroadcast_sum = federated_language.federated_broadcast(sum_result) + return federated_language.federated_sum(rebroadcast_sum) with self.assertRaisesRegex( ValueError, 'one aggregate dependent on another' @@ -175,7 +170,7 @@ def test_preserves_python_containers_in_after_merge(self): self.assertIsInstance( mergeable_form, mergeable_comp_execution_context.MergeableCompForm ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( mergeable_form.after_merge.type_signature.result, return_list.type_signature.result, ) @@ -203,9 +198,10 @@ def test_compilation_preserves_semantics_standalone_tf(self): self.assertEqual(expected_six, 6) def test_compiles_simple_noarg_computation(self): - @federated_computation.federated_computation() + + @federated_language.federated_computation() def return_server_value(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( return_server_value @@ -216,9 +212,10 @@ def return_server_value(): ) def test_preserves_semantics_of_noarg_computation(self): - @federated_computation.federated_computation() + + @federated_language.federated_computation() def return_server_value(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( return_server_value @@ -251,8 +248,8 @@ def test_compilation_preserves_semantics_server_placed_computation(self): def test_compiles_computation_with_aggregation_and_after(self): incoming_comp = build_whimsy_computation_with_aggregation_and_after( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( incoming_comp @@ -264,8 +261,8 @@ def test_compiles_computation_with_aggregation_and_after(self): def test_compilation_preserves_semantics_aggregation_and_after(self): incoming_comp = build_whimsy_computation_with_aggregation_and_after( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( incoming_comp @@ -278,9 +275,9 @@ def test_compilation_preserves_semantics_aggregation_and_after(self): def test_compiles_computation_with_before_aggregation_work(self): incoming_comp = build_whimsy_computation_with_before_aggregation_work( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( @@ -293,9 +290,9 @@ def test_compiles_computation_with_before_aggregation_work(self): def test_compiles_computation_with_false_aggregation_dependence(self): incoming_comp = build_whimsy_computation_with_false_aggregation_dependence( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( @@ -308,9 +305,9 @@ def test_compiles_computation_with_false_aggregation_dependence(self): def test_compilation_preserves_semantics_before_agg_work(self): incoming_comp = build_whimsy_computation_with_before_aggregation_work( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ) mergeable_form = mergeable_comp_compiler.compile_to_mergeable_comp_form( diff --git a/tensorflow_federated/python/core/backends/test/BUILD b/tensorflow_federated/python/core/backends/test/BUILD index 6af599223f..596c14c60d 100644 --- a/tensorflow_federated/python/core/backends/test/BUILD +++ b/tensorflow_federated/python/core/backends/test/BUILD @@ -38,14 +38,7 @@ py_library( "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_building_block_factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_computation_factory", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_conversions", + "@federated_language//federated_language", ], ) @@ -56,12 +49,7 @@ py_test( ":compiler", "//tensorflow_federated/python/core/backends/mapreduce:intrinsics", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_tree_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -76,11 +64,9 @@ py_library( "//tensorflow_federated/python/core/backends/native:compiler", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_executor_bindings", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:cpp_executor_factory", "//tensorflow_federated/python/core/impl/executors:executor_bindings", + "@federated_language//federated_language", ], ) @@ -94,12 +80,7 @@ py_test( ":cpp_execution_contexts", ":execution_contexts", "//tensorflow_federated/python/core/backends/mapreduce:intrinsics", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -110,10 +91,8 @@ py_library( ":compiler", "//tensorflow_federated/python/core/backends/native:compiler", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:executor_factory", + "@federated_language//federated_language", ], ) @@ -123,9 +102,6 @@ py_test( deps = [ ":execution_contexts", "//tensorflow_federated/python/core/backends/mapreduce:intrinsics", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/backends/test/compiler.py b/tensorflow_federated/python/core/backends/test/compiler.py index a18a292bfc..359bc03241 100644 --- a/tensorflow_federated/python/core/backends/test/compiler.py +++ b/tensorflow_federated/python/core/backends/test/compiler.py @@ -16,6 +16,7 @@ import collections from collections.abc import Callable +import federated_language import numpy as np import tensorflow as tf @@ -25,22 +26,14 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_conversions def _ensure_structure( int_or_structure, int_or_structure_type, possible_struct_type ): if isinstance( - int_or_structure_type, computation_types.StructType - ) or not isinstance(possible_struct_type, computation_types.StructType): + int_or_structure_type, federated_language.StructType + ) or not isinstance(possible_struct_type, federated_language.StructType): return int_or_structure else: # Broadcast int_or_structure to the same structure as the struct type @@ -52,8 +45,8 @@ def _ensure_structure( def _get_secure_intrinsic_reductions() -> dict[ str, Callable[ - [building_blocks.ComputationBuildingBlock], - building_blocks.ComputationBuildingBlock, + [federated_language.framework.ComputationBuildingBlock], + federated_language.framework.ComputationBuildingBlock, ], ]: """Returns map from intrinsic to reducing function. @@ -74,10 +67,12 @@ def _get_secure_intrinsic_reductions() -> dict[ """ def federated_secure_sum(arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - summand_arg = building_blocks.Selection(arg, index=0) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) + summand_arg = federated_language.framework.Selection(arg, index=0) summand_type = summand_arg.type_signature.member # pytype: disable=attribute-error - max_input_arg = building_blocks.Selection(arg, index=1) + max_input_arg = federated_language.framework.Selection(arg, index=1) max_input_type = max_input_arg.type_signature # Add the max_value as a second value in the zero, so it can be read during @@ -89,7 +84,7 @@ def federated_secure_sum(arg): summation_zero = tensorflow_building_block_factory.create_generic_constant( summand_type, 0 ) - aggregation_zero = building_blocks.Struct( + aggregation_zero = federated_language.framework.Struct( [summation_zero, max_input_arg], container_type=tuple ) @@ -132,9 +127,11 @@ def assert_all_coordinates_less_equal(x, m): second_operand_type=summand_type, ) ) - assert_less_equal_and_add = building_blocks.CompiledComputation( - assert_less_equal_and_add_proto, - type_signature=assert_less_equal_and_add_type, + assert_less_equal_and_add = ( + federated_language.framework.CompiledComputation( + assert_less_equal_and_add_proto, + type_signature=assert_less_equal_and_add_type, + ) ) def nested_plus(a, b): @@ -145,7 +142,7 @@ def nested_plus(a, b): nested_plus, operand_type=aggregation_zero.type_signature ) ) - plus_op = building_blocks.CompiledComputation( + plus_op = federated_language.framework.CompiledComputation( plus_proto, type_signature=plus_type ) @@ -153,15 +150,17 @@ def nested_plus(a, b): # of the struct (which was holding the max_value). drop_max_value_proto, drop_max_value_type = ( tensorflow_computation_factory.create_unary_operator( - lambda x: type_conversions.type_to_py_container(x[0], summand_type), + lambda x: federated_language.framework.type_to_py_container( + x[0], summand_type + ), aggregation_zero.type_signature, ) ) - drop_max_value_op = building_blocks.CompiledComputation( + drop_max_value_op = federated_language.framework.CompiledComputation( drop_max_value_proto, type_signature=drop_max_value_type ) - return building_block_factory.create_federated_aggregate( + return federated_language.framework.create_federated_aggregate( summand_arg, aggregation_zero, assert_less_equal_and_add, @@ -170,9 +169,11 @@ def nested_plus(a, b): ) def federated_secure_sum_bitwidth(arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - summand_arg = building_blocks.Selection(arg, index=0) - bitwidth_arg = building_blocks.Selection(arg, index=1) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) + summand_arg = federated_language.framework.Selection(arg, index=0) + bitwidth_arg = federated_language.framework.Selection(arg, index=1) # Comptue the max_input value from the provided bitwidth. def max_input_from_bitwidth(bitwidth): @@ -200,34 +201,40 @@ def compute_max_input(bits): max_input_from_bitwidth, bitwidth_arg.type_signature ) ) - compute_max_value_op = building_blocks.CompiledComputation( + compute_max_value_op = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) - max_value = building_blocks.Call(compute_max_value_op, bitwidth_arg) + max_value = federated_language.framework.Call( + compute_max_value_op, bitwidth_arg + ) return federated_secure_sum( - building_blocks.Struct([summand_arg, max_value]) + federated_language.framework.Struct([summand_arg, max_value]) ) def federated_secure_modular_sum(arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - if not isinstance(arg.type_signature, computation_types.StructType): + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) + if not isinstance(arg.type_signature, federated_language.StructType): raise ValueError( f'Expected a `tff.StructType`, found {arg.type_signature}.' ) - if isinstance(arg.type_signature, computation_types.StructWithPythonType): + if isinstance(arg.type_signature, federated_language.StructWithPythonType): container_type = arg.type_signature.python_container else: container_type = None - summand_arg = building_blocks.Selection(arg, index=0) - raw_summed_values = building_block_factory.create_federated_sum(summand_arg) + summand_arg = federated_language.framework.Selection(arg, index=0) + raw_summed_values = federated_language.framework.create_federated_sum( + summand_arg + ) - unplaced_modulus = building_blocks.Selection(arg, index=1) - placed_modulus = building_block_factory.create_federated_value( - unplaced_modulus, placements.SERVER + unplaced_modulus = federated_language.framework.Selection(arg, index=1) + placed_modulus = federated_language.framework.create_federated_value( + unplaced_modulus, federated_language.SERVER ) - modulus_arg = building_block_factory.create_federated_zip( - building_blocks.Struct( + modulus_arg = federated_language.framework.create_federated_zip( + federated_language.framework.Struct( [raw_summed_values, placed_modulus], container_type=container_type ) ) @@ -247,22 +254,24 @@ def map_structure_mod(summed_values, modulus): second_operand_type=placed_modulus.type_signature.member, # pytype: disable=attribute-error ) ) - modulus_fn = building_blocks.CompiledComputation( + modulus_fn = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) - modulus_computed = building_block_factory.create_federated_apply( + modulus_computed = federated_language.framework.create_federated_apply( modulus_fn, modulus_arg ) return modulus_computed def federated_secure_select(arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - client_keys_arg = building_blocks.Selection(arg, index=0) - max_key_arg = building_blocks.Selection(arg, index=1) - server_val_arg = building_blocks.Selection(arg, index=2) - select_fn_arg = building_blocks.Selection(arg, index=3) - return building_block_factory.create_federated_select( + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) + client_keys_arg = federated_language.framework.Selection(arg, index=0) + max_key_arg = federated_language.framework.Selection(arg, index=1) + server_val_arg = federated_language.framework.Selection(arg, index=2) + select_fn_arg = federated_language.framework.Selection(arg, index=3) + return federated_language.framework.create_federated_select( client_keys_arg, max_key_arg, server_val_arg, @@ -272,15 +281,21 @@ def federated_secure_select(arg): secure_intrinsic_bodies_by_uri = collections.OrderedDict([ ( - intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri, + federated_language.framework.FEDERATED_SECURE_SUM_BITWIDTH.uri, federated_secure_sum_bitwidth, ), ( mapreduce_intrinsics.FEDERATED_SECURE_MODULAR_SUM.uri, federated_secure_modular_sum, ), - (intrinsic_defs.FEDERATED_SECURE_SUM.uri, federated_secure_sum), - (intrinsic_defs.FEDERATED_SECURE_SELECT.uri, federated_secure_select), + ( + federated_language.framework.FEDERATED_SECURE_SUM.uri, + federated_secure_sum, + ), + ( + federated_language.framework.FEDERATED_SECURE_SELECT.uri, + federated_secure_select, + ), ]) return secure_intrinsic_bodies_by_uri @@ -289,7 +304,7 @@ def _replace_secure_intrinsics_with_insecure_bodies(comp): """Iterates over all secure intrinsic bodies, inlining the intrinsics. This function operates on the AST level; meaning, it takes in a - `building_blocks.ComputationBuildingBlock` as an argument and + `federated_language.framework.ComputationBuildingBlock` as an argument and returns one as well. `replace_intrinsics_with_bodies` is intended to be the standard reduction function, which will reduce all currently implemented intrinsics to their bodies. @@ -299,18 +314,20 @@ def _replace_secure_intrinsics_with_insecure_bodies(comp): function is ordered from more complex intrinsic to less complex intrinsics. Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` in which we - wish to replace all intrinsics with their bodies. + comp: Instance of `federated_language.framework.ComputationBuildingBlock` in + which we wish to replace all intrinsics with their bodies. Returns: - Instance of `building_blocks.ComputationBuildingBlock` with all + Instance of `federated_language.framework.ComputationBuildingBlock` with all the intrinsics from `intrinsic_bodies.py` inlined with their bodies, along with a Boolean indicating whether there was any inlining in fact done. Raises: TypeError: If the types don't match. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) secure_bodies = _get_secure_intrinsic_reductions() transformed = False for uri, body in secure_bodies.items(): @@ -340,7 +357,7 @@ def replace_secure_intrinsics_with_bodies(comp): replaced_intrinsic_bodies, _ = ( _replace_secure_intrinsics_with_insecure_bodies(comp.to_building_block()) ) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=replaced_intrinsic_bodies.proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) diff --git a/tensorflow_federated/python/core/backends/test/compiler_test.py b/tensorflow_federated/python/core/backends/test/compiler_test.py index f0a49b7aca..b6e1e5a141 100644 --- a/tensorflow_federated/python/core/backends/test/compiler_test.py +++ b/tensorflow_federated/python/core/backends/test/compiler_test.py @@ -14,28 +14,23 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.core.backends.mapreduce import intrinsics from tensorflow_federated.python.core.backends.test import compiler from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils def _count_intrinsics(comp, uri): def _predicate(comp): return ( - isinstance(comp, building_blocks.Intrinsic) + isinstance(comp, federated_language.framework.Intrinsic) and uri is not None and comp.uri == uri ) - return tree_analysis.count(comp, _predicate) + return federated_language.framework.computation_count(comp, _predicate) class ReplaceIntrinsicsWithBodiesTest(parameterized.TestCase): @@ -48,17 +43,19 @@ class ReplaceIntrinsicsWithBodiesTest(parameterized.TestCase): ('per_leaf_bitwidth', [np.int64, [np.int32]], [np.int32, [np.int32]]), ) def test_federated_secure_sum(self, value_dtype, bitwidth_type): - uri = intrinsic_defs.FEDERATED_SECURE_SUM.uri - comp = building_blocks.Intrinsic( + uri = federated_language.framework.FEDERATED_SECURE_SUM.uri + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( + federated_language.FunctionType( [ - computation_types.FederatedType( - value_dtype, placements.CLIENTS + federated_language.FederatedType( + value_dtype, federated_language.CLIENTS ), - computation_types.to_type(bitwidth_type), + federated_language.to_type(bitwidth_type), ], - computation_types.FederatedType(value_dtype, placements.SERVER), + federated_language.FederatedType( + value_dtype, federated_language.SERVER + ), ), ) self.assertGreater(_count_intrinsics(comp, uri), 0) @@ -68,7 +65,7 @@ def test_federated_secure_sum(self, value_dtype, bitwidth_type): ) self.assertFalse(modified) self.assertGreater(_count_intrinsics(comp, uri), 0) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) # Now replace bodies including secure intrinsics. @@ -76,11 +73,14 @@ def test_federated_secure_sum(self, value_dtype, bitwidth_type): compiler._replace_secure_intrinsics_with_insecure_bodies(comp) ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater( - _count_intrinsics(reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri), 0 + _count_intrinsics( + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri + ), + 0, ) @parameterized.named_parameters( @@ -91,18 +91,18 @@ def test_federated_secure_sum(self, value_dtype, bitwidth_type): ('per_leaf_bitwidth', [np.int64, [np.int32]], [np.int32, [np.int32]]), ) def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type): - uri = intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri - comp = building_blocks.Intrinsic( + uri = federated_language.framework.FEDERATED_SECURE_SUM_BITWIDTH.uri + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( + federated_language.FunctionType( parameter=[ - computation_types.FederatedType( - value_dtype, placements.CLIENTS + federated_language.FederatedType( + value_dtype, federated_language.CLIENTS ), - computation_types.to_type(bitwidth_type), + federated_language.to_type(bitwidth_type), ], - result=computation_types.FederatedType( - value_dtype, placements.SERVER + result=federated_language.FederatedType( + value_dtype, federated_language.SERVER ), ), ) @@ -112,7 +112,7 @@ def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type): ) self.assertFalse(modified) self.assertGreater(_count_intrinsics(comp, uri), 0) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) # Now replace bodies including secure intrinsics. @@ -120,11 +120,14 @@ def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type): compiler._replace_secure_intrinsics_with_insecure_bodies(comp) ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater( - _count_intrinsics(reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri), 0 + _count_intrinsics( + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri + ), + 0, ) @parameterized.named_parameters( @@ -136,17 +139,17 @@ def test_federated_secure_sum_bitwidth(self, value_dtype, bitwidth_type): ) def test_federated_secure_modular_sum(self, value_dtype, modulus_type): uri = intrinsics.FEDERATED_SECURE_MODULAR_SUM.uri - comp = building_blocks.Intrinsic( + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( + federated_language.FunctionType( parameter=[ - computation_types.FederatedType( - value_dtype, placements.CLIENTS + federated_language.FederatedType( + value_dtype, federated_language.CLIENTS ), - computation_types.to_type(modulus_type), + federated_language.to_type(modulus_type), ], - result=computation_types.FederatedType( - value_dtype, placements.SERVER + result=federated_language.FederatedType( + value_dtype, federated_language.SERVER ), ), ) @@ -156,7 +159,7 @@ def test_federated_secure_modular_sum(self, value_dtype, modulus_type): ) self.assertFalse(modified) self.assertGreater(_count_intrinsics(comp, uri), 0) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) # Now replace bodies including secure intrinsics. @@ -166,34 +169,38 @@ def test_federated_secure_modular_sum(self, value_dtype, modulus_type): self.assertTrue(modified) # Inserting tensorflow, as we do here, does not preserve python containers # currently. - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( comp.type_signature, reduced.type_signature ) self.assertGreater( - _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SUM.uri), 0 + _count_intrinsics( + reduced, federated_language.framework.FEDERATED_SUM.uri + ), + 0, ) def test_federated_secure_select(self): - uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri - comp = building_blocks.Intrinsic( + uri = federated_language.framework.FEDERATED_SECURE_SELECT.uri + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( + federated_language.FunctionType( [ - computation_types.FederatedType( - np.int32, placements.CLIENTS + federated_language.FederatedType( + np.int32, federated_language.CLIENTS ), # client_keys - computation_types.FederatedType( - np.int32, placements.SERVER + federated_language.FederatedType( + np.int32, federated_language.SERVER ), # max_key - computation_types.FederatedType( - np.float32, placements.SERVER + federated_language.FederatedType( + np.float32, federated_language.SERVER ), # server_state - computation_types.FunctionType( + federated_language.FunctionType( [np.float32, np.int32], np.float32 ), # select_fn ], - computation_types.FederatedType( - computation_types.SequenceType(np.float32), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.float32), + federated_language.CLIENTS, ), ), ) @@ -204,7 +211,7 @@ def test_federated_secure_select(self): ) self.assertFalse(modified) self.assertGreater(_count_intrinsics(comp, uri), 0) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) # Now replace bodies including secure intrinsics. @@ -212,11 +219,14 @@ def test_federated_secure_select(self): compiler._replace_secure_intrinsics_with_insecure_bodies(comp) ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater( - _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0 + _count_intrinsics( + reduced, federated_language.framework.FEDERATED_SELECT.uri + ), + 0, ) diff --git a/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py index 322ac3ca2d..bd2ba9d3d0 100644 --- a/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py +++ b/tensorflow_federated/python/core/backends/test/cpp_execution_contexts.py @@ -21,15 +21,13 @@ from absl import flags from absl import logging +import federated_language import portpicker from tensorflow_federated.python.core.backends.native import compiler as native_compiler from tensorflow_federated.python.core.backends.test import compiler as test_compiler from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_executor_bindings from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executor_stacks import cpp_executor_factory from tensorflow_federated.python.core.impl.executors import executor_bindings @@ -57,7 +55,7 @@ def create_async_test_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> async_execution_context.AsyncExecutionContext: +) -> federated_language.framework.AsyncExecutionContext: """Creates an async execution context for local testing of computations. Test execution contexts are useful for simulating the behavior of secure @@ -93,7 +91,7 @@ def _compile(comp): max_concurrent_computation_calls=max_concurrent_computation_calls, leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) - context = async_execution_context.AsyncExecutionContext( + context = federated_language.framework.AsyncExecutionContext( executor_fn=factory, compiler_fn=_compile, transform_args=tensorflow_computation.transform_args, @@ -131,14 +129,14 @@ def set_async_test_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) def create_sync_interprocess_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> sync_execution_context.SyncExecutionContext: +) -> federated_language.framework.SyncExecutionContext: """Creates an execution context backed by TFF-C++ runtime. This execution context starts a TFF-C++ worker in a subprocess on the local @@ -229,7 +227,7 @@ def initialize_channel(self) -> None: f'localhost:{port}' ) - return sync_execution_context.SyncExecutionContext( + return federated_language.framework.SyncExecutionContext( executor_fn=ManagedServiceContext(), compiler_fn=native_compiler.desugar_and_transform_to_native, transform_args=tensorflow_computation.transform_args, @@ -242,7 +240,7 @@ def create_sync_test_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> sync_execution_context.SyncExecutionContext: +) -> federated_language.framework.SyncExecutionContext: """Creates an execution context for local testing of computations. Test execution contexts are useful for simulating the behavior of secure @@ -278,7 +276,7 @@ def _compile(comp): max_concurrent_computation_calls=max_concurrent_computation_calls, leaf_executor_fn=_create_tensorflow_backend_execution_stack, ) - context = sync_execution_context.SyncExecutionContext( + context = federated_language.framework.SyncExecutionContext( executor_fn=factory, compiler_fn=_compile, transform_args=tensorflow_computation.transform_args, @@ -316,4 +314,4 @@ def set_sync_test_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) diff --git a/tensorflow_federated/python/core/backends/test/cpp_execution_contexts_test.py b/tensorflow_federated/python/core/backends/test/cpp_execution_contexts_test.py index 26d017f8c5..f2c5eb9250 100644 --- a/tensorflow_federated/python/core/backends/test/cpp_execution_contexts_test.py +++ b/tensorflow_federated/python/core/backends/test/cpp_execution_contexts_test.py @@ -18,18 +18,13 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tree from tensorflow_federated.python.core.backends.mapreduce import intrinsics as mapreduce_intrinsics from tensorflow_federated.python.core.backends.test import cpp_execution_contexts from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _assert_signature_equal(first_obj, second_obj): @@ -69,7 +64,7 @@ def test_has_same_signature(self): def test_returns_async_context(self): context = cpp_execution_contexts.create_async_test_cpp_execution_context() self.assertIsInstance( - context, async_execution_context.AsyncExecutionContext + context, federated_language.framework.AsyncExecutionContext ) @@ -92,7 +87,9 @@ def test_has_same_signature(self): def test_returns_sync_context(self): context = cpp_execution_contexts.create_sync_test_cpp_execution_context() - self.assertIsInstance(context, sync_execution_context.SyncExecutionContext) + self.assertIsInstance( + context, federated_language.framework.SyncExecutionContext + ) class SetSyncTestCPPExecutionContextTest(absltest.TestCase): @@ -111,16 +108,16 @@ class SecureModularSumTest( # pyformat: disable @parameterized.named_parameters( ('one_client_not_divisible', [1], 1, - computation_types.FederatedType(np.int32, placements.CLIENTS)), + federated_language.FederatedType(np.int32, federated_language.CLIENTS)), ('two_clients_none_divisible', [1, 2], 3, - computation_types.FederatedType(np.int32, placements.CLIENTS)), + federated_language.FederatedType(np.int32, federated_language.CLIENTS)), ('three_clients_one_divisible', [1, 2, 10], 3, - computation_types.FederatedType(np.int32, placements.CLIENTS)), + federated_language.FederatedType(np.int32, federated_language.CLIENTS)), ('all_clients_divisible_by_modulus', [x * 5 for x in range(5)], 0, - computation_types.FederatedType(np.int32, placements.CLIENTS)), + federated_language.FederatedType(np.int32, federated_language.CLIENTS)), ('nonscalar_struct_arg', [([1, 2], 3), ([4, 5], 6)], (np.array([0, 2], dtype=np.int32), 4), - computation_types.FederatedType(((np.int32, [2]), np.int32), placements.CLIENTS)), + federated_language.FederatedType(((np.int32, [2]), np.int32), federated_language.CLIENTS)), ) # pyformat: enable def test_executes_computation_with_modular_secure_sum_integer_modulus( @@ -130,7 +127,7 @@ def test_executes_computation_with_modular_secure_sum_integer_modulus( modulus = 5 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def modular_sum_by_five(arg): return mapreduce_intrinsics.federated_secure_modular_sum(arg, modulus) @@ -151,13 +148,17 @@ def modular_sum_by_five(arg): 'one_client_not_divisible', [1], 1, - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ), ( 'two_clients_none_divisible', [1, 2], 3, - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ), ) # pyformat: enable @@ -168,7 +169,7 @@ async def test_async_executes_computation_with_modular_secure_sum_integer_modulu modulus = 5 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def modular_sum_by_five(arg): return mapreduce_intrinsics.federated_secure_modular_sum(arg, modulus) @@ -179,32 +180,32 @@ def modular_sum_by_five(arg): 'one_client_not_divisible', [[1, 2]], [1, 2], - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ), ( 'two_clients_none_divisible', [[1, 2], [3, 4]], [4, 6], - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ), ( 'three_clients_one_divisible', [[1, 2], [3, 4], [10, 14]], [4, 6], - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ), ( 'two_clients_one_partially_divisible', [[1, 2], [3, 4], [10, 15]], [4, 0], - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ), ) @@ -214,7 +215,7 @@ def test_executes_computation_with_modular_secure_sum_struct_modulus( cpp_execution_contexts.set_sync_test_cpp_execution_context() modulus = [5, 7] - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def modular_sum_by_five(arg): return mapreduce_intrinsics.federated_secure_modular_sum(arg, modulus) @@ -237,11 +238,11 @@ def test_executes_computation_with_bitwidth_secure_sum_large_bitwidth( bitwidth = 32 expected_result = sum(arg) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def sum_with_bitwidth(arg): - return intrinsics.federated_secure_sum_bitwidth(arg, bitwidth) + return federated_language.federated_secure_sum_bitwidth(arg, bitwidth) self.assertEqual(expected_result, sum_with_bitwidth(arg)) @@ -253,24 +254,24 @@ async def test_async_executes_computation_with_bitwidth_secure_sum_large_bitwidt bitwidth = 32 expected_result = sum(arg) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def sum_with_bitwidth(arg): - return intrinsics.federated_secure_sum_bitwidth(arg, bitwidth) + return federated_language.federated_secure_sum_bitwidth(arg, bitwidth) self.assertEqual(expected_result, await sum_with_bitwidth(arg)) # pyformat: disable @parameterized.named_parameters( ('two_clients_scalar_tensors', [[1, 2], [3, 4]], [4, 6], - computation_types.FederatedType([np.int32, np.int32], placements.CLIENTS)), + federated_language.FederatedType([np.int32, np.int32], federated_language.CLIENTS)), ('two_clients_nonscalar_tensors', [[np.ones(shape=[10], dtype=np.int32), 2], [np.ones(shape=[10], dtype=np.int32), 4]], [2 * np.ones(shape=[10], dtype=np.int32), 6], - computation_types.FederatedType([ - computation_types.TensorType(dtype=np.int32, shape=[10]), np.int32], placements.CLIENTS) + federated_language.FederatedType([ + federated_language.TensorType(dtype=np.int32, shape=[10]), np.int32], federated_language.CLIENTS) ), ) # pyformat: enable @@ -281,9 +282,9 @@ def test_executes_computation_with_argument_structure( bitwidth = 32 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def sum_with_bitwidth(arg): - return intrinsics.federated_secure_sum_bitwidth(arg, bitwidth) + return federated_language.federated_secure_sum_bitwidth(arg, bitwidth) actual_result = sum_with_bitwidth(arg) @@ -303,11 +304,11 @@ def test_raises_with_arguments_over_max_value(self): max_value = 1 - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) with self.assertRaisesRegex( Exception, 'client value larger than maximum specified for secure sum' @@ -326,11 +327,11 @@ def test_executes_computation_with_secure_sum_under_max_values(self, arg): expected_result = sum(arg) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) self.assertEqual(expected_result, secure_sum(arg)) @@ -343,25 +344,25 @@ async def test_async_executes_computation_with_secure_sum_under_max_values( expected_result = sum(arg) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) self.assertEqual(expected_result, await secure_sum(arg)) # pyformat: disable @parameterized.named_parameters( ('two_clients_scalar_tensors', [[1, 2], [3, 4]], [4, 6], - computation_types.FederatedType([np.int32, np.int32], placements.CLIENTS)), + federated_language.FederatedType([np.int32, np.int32], federated_language.CLIENTS)), ('two_clients_nonscalar_tensors', [[np.ones(shape=[10], dtype=np.int32), 2], [np.ones(shape=[10], dtype=np.int32), 4]], [2 * np.ones(shape=[10], dtype=np.int32), 6], - computation_types.FederatedType([ - computation_types.TensorType( - dtype=np.int32, shape=[10]), np.int32], placements.CLIENTS)), + federated_language.FederatedType([ + federated_language.TensorType( + dtype=np.int32, shape=[10]), np.int32], federated_language.CLIENTS)), ) # pyformat: enable def test_executes_computation_with_argument_structure( @@ -371,9 +372,9 @@ def test_executes_computation_with_argument_structure( max_value = 100 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) actual_result = secure_sum(arg) diff --git a/tensorflow_federated/python/core/backends/test/execution_contexts.py b/tensorflow_federated/python/core/backends/test/execution_contexts.py index f62bb2008b..fd34a93c1e 100644 --- a/tensorflow_federated/python/core/backends/test/execution_contexts.py +++ b/tensorflow_federated/python/core/backends/test/execution_contexts.py @@ -13,12 +13,10 @@ # limitations under the License. """Execution contexts for the test backend.""" +import federated_language from tensorflow_federated.python.core.backends.native import compiler as native_compiler from tensorflow_federated.python.core.backends.test import compiler as test_compiler from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executor_stacks import executor_factory @@ -27,7 +25,7 @@ def create_async_test_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> async_execution_context.AsyncExecutionContext: +) -> federated_language.framework.AsyncExecutionContext: """Creates an execution context that executes computations locally.""" factory = executor_factory.local_cpp_executor_factory( default_num_clients=default_num_clients, @@ -42,7 +40,7 @@ def _compile(comp): comp = native_compiler.desugar_and_transform_to_native(comp) return comp - return async_execution_context.AsyncExecutionContext( + return federated_language.framework.AsyncExecutionContext( executor_fn=factory, compiler_fn=_compile, transform_args=tensorflow_computation.transform_args, @@ -62,7 +60,7 @@ def set_async_test_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) def create_sync_test_cpp_execution_context( @@ -70,7 +68,7 @@ def create_sync_test_cpp_execution_context( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> sync_execution_context.SyncExecutionContext: +) -> federated_language.framework.SyncExecutionContext: """Creates an execution context that executes computations locally.""" factory = executor_factory.local_cpp_executor_factory( default_num_clients=default_num_clients, @@ -85,7 +83,7 @@ def _compile(comp): comp = native_compiler.desugar_and_transform_to_native(comp) return comp - return sync_execution_context.SyncExecutionContext( + return federated_language.framework.SyncExecutionContext( executor_fn=factory, compiler_fn=_compile, transform_args=tensorflow_computation.transform_args, @@ -105,4 +103,4 @@ def set_sync_test_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, stream_structs=stream_structs, ) - context_stack_impl.context_stack.set_default_context(context) + federated_language.framework.global_context_stack.set_default_context(context) diff --git a/tensorflow_federated/python/core/backends/test/execution_contexts_test.py b/tensorflow_federated/python/core/backends/test/execution_contexts_test.py index 5e5456a15a..c78392c2b4 100644 --- a/tensorflow_federated/python/core/backends/test/execution_contexts_test.py +++ b/tensorflow_federated/python/core/backends/test/execution_contexts_test.py @@ -14,19 +14,18 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import grpc import numpy as np from tensorflow_federated.python.core.backends.mapreduce import intrinsics as mapreduce_intrinsics from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - -_CLIENTS_INT = computation_types.FederatedType(np.int32, placements.CLIENTS) -_CLIENTS_INT_LIST = computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + +_CLIENTS_INT = federated_language.FederatedType( + np.int32, federated_language.CLIENTS +) +_CLIENTS_INT_LIST = federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ) @@ -50,8 +49,8 @@ def setUp(self): 'nonscalar_struct_arg', [([1, 2], 3), ([4, 5], 6)], (np.array([0, 2], dtype=np.int32), 4), - computation_types.FederatedType( - ((np.int32, [2]), np.int32), placements.CLIENTS + federated_language.FederatedType( + ((np.int32, [2]), np.int32), federated_language.CLIENTS ), ), ) @@ -60,7 +59,7 @@ def test_executes_computation_with_modular_secure_sum_integer_modulus( ): modulus = 5 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def modular_sum_by_five(arg): return mapreduce_intrinsics.federated_secure_modular_sum(arg, modulus) @@ -101,7 +100,7 @@ def test_executes_computation_with_modular_secure_sum_struct_modulus( ): modulus = [5, 7] - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def modular_sum_by_five(arg): return mapreduce_intrinsics.federated_secure_modular_sum(arg, modulus) @@ -129,11 +128,11 @@ def test_executes_computation_with_bitwidth_secure_sum_large_bitwidth( expected_result = sum(arg) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def sum_with_bitwidth(arg): - return intrinsics.federated_secure_sum_bitwidth(arg, bitwidth) + return federated_language.federated_secure_sum_bitwidth(arg, bitwidth) self.assertEqual(expected_result, sum_with_bitwidth(arg)) @@ -142,8 +141,8 @@ def sum_with_bitwidth(arg): 'two_clients_scalar_tensors', [[1, 2], [3, 4]], [4, 6], - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ), ( @@ -153,12 +152,12 @@ def sum_with_bitwidth(arg): [np.ones(shape=[10], dtype=np.int32), 4], ], [2 * np.ones(shape=[10], dtype=np.int32), 6], - computation_types.FederatedType( + federated_language.FederatedType( [ - computation_types.TensorType(dtype=np.int32, shape=[10]), + federated_language.TensorType(dtype=np.int32, shape=[10]), np.int32, ], - placements.CLIENTS, + federated_language.CLIENTS, ), ), ) @@ -167,9 +166,9 @@ def test_executes_computation_with_argument_structure( ): bitwidth = 32 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def sum_with_bitwidth(arg): - return intrinsics.federated_secure_sum_bitwidth(arg, bitwidth) + return federated_language.federated_secure_sum_bitwidth(arg, bitwidth) actual_result = sum_with_bitwidth(arg) @@ -189,11 +188,11 @@ def setUp(self): def test_raises_with_arguments_over_max_value(self): max_value = 1 - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) with self.assertRaises(grpc.RpcError): secure_sum([2, 4]) @@ -211,11 +210,11 @@ def test_executes_computation_with_secure_sum_under_max_values(self, arg): expected_result = sum(arg) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) self.assertEqual(expected_result, secure_sum(arg)) @@ -224,8 +223,8 @@ def secure_sum(arg): 'two_clients_scalar_tensors', [[1, 2], [3, 4]], [4, 6], - computation_types.FederatedType( - [np.int32, np.int32], placements.CLIENTS + federated_language.FederatedType( + [np.int32, np.int32], federated_language.CLIENTS ), ), ( @@ -235,12 +234,12 @@ def secure_sum(arg): [np.ones(shape=[10], dtype=np.int32), 4], ], [2 * np.ones(shape=[10], dtype=np.int32), 6], - computation_types.FederatedType( + federated_language.FederatedType( [ - computation_types.TensorType(dtype=np.int32, shape=[10]), + federated_language.TensorType(dtype=np.int32, shape=[10]), np.int32, ], - placements.CLIENTS, + federated_language.CLIENTS, ), ), ) @@ -249,9 +248,9 @@ def test_executes_computation_with_argument_structure( ): max_value = 100 - @federated_computation.federated_computation(tff_type) + @federated_language.federated_computation(tff_type) def secure_sum(arg): - return intrinsics.federated_secure_sum(arg, max_value) + return federated_language.federated_secure_sum(arg, max_value) actual_result = secure_sum(arg) diff --git a/tensorflow_federated/python/core/backends/xla/BUILD b/tensorflow_federated/python/core/backends/xla/BUILD index 3edd790603..eb29e39d09 100644 --- a/tensorflow_federated/python/core/backends/xla/BUILD +++ b/tensorflow_federated/python/core/backends/xla/BUILD @@ -31,11 +31,9 @@ py_library( "//tensorflow_federated/python/core/backends/native:compiler", "//tensorflow_federated/python/core/environments/jax_frontend:jax_computation", "//tensorflow_federated/python/core/environments/xla_backend:xla_executor_bindings", - "//tensorflow_federated/python/core/impl/context_stack:set_default_context", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:cpp_executor_factory", "//tensorflow_federated/python/core/impl/executors:executor_bindings", + "@federated_language//federated_language", ], ) @@ -48,11 +46,6 @@ py_test( deps = [ ":cpp_execution_contexts", "//tensorflow_federated/python/core/environments/jax_frontend:jax_computation", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_test_utils", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py index 2040d7ce4b..f4001427d2 100644 --- a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py +++ b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts.py @@ -13,12 +13,10 @@ # limitations under the License. """Execution contexts for the XLA backend.""" +import federated_language from tensorflow_federated.python.core.backends.native import compiler from tensorflow_federated.python.core.environments.jax_frontend import jax_computation from tensorflow_federated.python.core.environments.xla_backend import xla_executor_bindings -from tensorflow_federated.python.core.impl.context_stack import set_default_context -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.execution_contexts import sync_execution_context from tensorflow_federated.python.core.impl.executor_stacks import cpp_executor_factory from tensorflow_federated.python.core.impl.executors import executor_bindings @@ -55,7 +53,7 @@ def create_async_local_cpp_execution_context( max_concurrent_computation_calls=max_concurrent_computation_calls, leaf_executor_fn=_create_xla_backend_execution_stack, ) - return async_execution_context.AsyncExecutionContext( + return federated_language.framework.AsyncExecutionContext( executor_fn=factory, compiler_fn=compiler.transform_to_native_form, transform_args=jax_computation.transform_args, @@ -70,7 +68,7 @@ def set_async_local_cpp_execution_context( default_num_clients=default_num_clients, max_concurrent_computation_calls=max_concurrent_computation_calls, ) - set_default_context.set_default_context(context) + federated_language.framework.set_default_context(context) def create_sync_local_cpp_execution_context( @@ -98,7 +96,7 @@ def create_sync_local_cpp_execution_context( # TODO: b/255978089 - implement lowering to federated_aggregate to create JAX # computations instead of TensorFlow, similar to "desugar intrinsics" in the # native backend. - return sync_execution_context.SyncExecutionContext( + return federated_language.framework.SyncExecutionContext( executor_fn=factory, compiler_fn=compiler.transform_to_native_form, transform_args=jax_computation.transform_args, @@ -113,4 +111,4 @@ def set_sync_local_cpp_execution_context( default_num_clients=default_num_clients, max_concurrent_computation_calls=max_concurrent_computation_calls, ) - set_default_context.set_default_context(context) + federated_language.framework.set_default_context(context) diff --git a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py index 18926a4537..f98cac7909 100644 --- a/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py +++ b/tensorflow_federated/python/core/backends/xla/cpp_execution_contexts_test.py @@ -15,16 +15,11 @@ import unittest from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.backends.xla import cpp_execution_contexts from tensorflow_federated.python.core.environments.jax_frontend import jax_computation -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements class AsyncLocalCppExecutionContextTest( @@ -33,9 +28,9 @@ class AsyncLocalCppExecutionContextTest( def test_create_async_local_cpp_execution_context_returns_async_context(self): context = cpp_execution_contexts.create_async_local_cpp_execution_context() - self.assertIsInstance(context, context_base.AsyncContext) + self.assertIsInstance(context, federated_language.framework.AsyncContext) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_async_local_cpp_execution_context ) async def test_jax_computation_returns_result(self): @@ -52,9 +47,9 @@ class SyncLocalCppExecutionContextTest(absltest.TestCase): def test_create_sync_local_cpp_execution_context_returns_sync_context(self): context = cpp_execution_contexts.create_sync_local_cpp_execution_context() - self.assertIsInstance(context, context_base.SyncContext) + self.assertIsInstance(context, federated_language.framework.SyncContext) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_sync_local_cpp_execution_context ) def test_jax_computation_returns_result(self): @@ -66,7 +61,7 @@ def _comp(a, b): self.assertEqual(actual_result, 3) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_sync_local_cpp_execution_context ) def test_federated_aggergate(self): @@ -85,11 +80,11 @@ def _identity(a): def zeros(): return np.float32(0) - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.CLIENTS) ) def aggregate(client_values): - return intrinsics.federated_aggregate( + return federated_language.federated_aggregate( client_values, zero=zeros(), accumulate=_add, @@ -101,14 +96,14 @@ def aggregate(client_values): aggregate([np.float32(1), np.float32(2), np.float32(3)]), np.float32(6) ) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_sync_local_cpp_execution_context ) def test_sequence_reduce(self): sequence = list(range(10)) - @federated_computation.federated_computation( - computation_types.SequenceType(np.int32) + @federated_language.federated_computation( + federated_language.SequenceType(np.int32) ) def comp(x): @jax_computation.jax_computation @@ -119,19 +114,19 @@ def _zero(): def _add(a, b): return a + b - return intrinsics.sequence_reduce(x, _zero(), _add) + return federated_language.sequence_reduce(x, _zero(), _add) self.assertEqual(comp(sequence), sum(range(10))) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_sync_local_cpp_execution_context ) def test_federated_sequence_reduce(self): sequence = list(range(10)) - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.SERVER + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.SERVER ) ) def comp(x): @@ -143,26 +138,26 @@ def _zero(): def _add(a, b): return a + b - @federated_computation.federated_computation( - computation_types.SequenceType(np.int32) + @federated_language.federated_computation( + federated_language.SequenceType(np.int32) ) def _sum(sequence): - return intrinsics.sequence_reduce(sequence, _zero(), _add) + return federated_language.sequence_reduce(sequence, _zero(), _add) - return intrinsics.federated_map(_sum, x) + return federated_language.federated_map(_sum, x) self.assertEqual(comp(sequence), sum(range(10))) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_sync_local_cpp_execution_context ) def test_federated_sum(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) def comp(x): - return intrinsics.federated_sum(x) + return federated_language.federated_sum(x) # TODO: b/27340091 - use a TFF specific error message after converting the # result coming out of the execution stack. @@ -172,16 +167,16 @@ def comp(x): # self.assertEqual(comp([1, 2, 3]), 6) comp([1, 2, 3]) - @context_stack_test_utils.with_context( + @federated_language.framework.with_context( cpp_execution_contexts.create_sync_local_cpp_execution_context ) def test_unweighted_federated_mean(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.CLIENTS) ) def comp(x): - return intrinsics.federated_mean(x) + return federated_language.federated_mean(x) # TODO: b/27340091 - use a TFF specific error message after converting the # result coming out of the execution stack. diff --git a/tensorflow_federated/python/core/environments/jax_frontend/BUILD b/tensorflow_federated/python/core/environments/jax_frontend/BUILD index d3c3bf102c..d54ac04127 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/BUILD +++ b/tensorflow_federated/python/core/environments/jax_frontend/BUILD @@ -28,11 +28,7 @@ py_library( srcs = ["jax_computation.py"], deps = [ ":jax_serialization", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", + "@federated_language//federated_language", ], ) @@ -42,16 +38,14 @@ py_test( srcs = ["jax_computation_test.py"], deps = [ ":jax_computation", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:polymorphic_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) py_library( name = "jax_computation_context", srcs = ["jax_computation_context.py"], - deps = ["//tensorflow_federated/python/core/impl/context_stack:context_base"], + deps = ["@federated_language//federated_language"], ) py_test( @@ -69,16 +63,11 @@ py_library( srcs = ["jax_serialization.py"], deps = [ ":jax_computation_context", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/xla_backend:xla_serialization", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:typed_object", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -88,12 +77,9 @@ py_test( srcs = ["jax_serialization_test.py"], deps = [ ":jax_serialization", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/xla_backend:xla_serialization", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) diff --git a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation.py b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation.py index b0344a2e4d..5b6dbeab37 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation.py +++ b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation.py @@ -16,33 +16,29 @@ from collections.abc import Callable, Sequence from typing import Optional, Union +import federated_language import jax import numpy as np import tree from tensorflow_federated.python.core.environments.jax_frontend import jax_serialization -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis def _contains_dtype( - type_spec: computation_types.Type, + type_spec: federated_language.Type, dtype: Union[type[np.generic], Sequence[type[np.generic]]], ) -> bool: """Returns `True` if `type_spec` contains the `dtype`.""" if not isinstance(dtype, Sequence): dtype = [dtype] - def predicate(type_spec: computation_types.Type) -> bool: + def predicate(type_spec: federated_language.Type) -> bool: return ( - isinstance(type_spec, computation_types.TensorType) + isinstance(type_spec, federated_language.TensorType) and type_spec.dtype.type in dtype ) - return type_analysis.contains(type_spec, predicate) + return federated_language.framework.type_contains(type_spec, predicate) def _to_numpy(value: object) -> object: @@ -70,26 +66,28 @@ def transform_result(result: object) -> object: def _jax_wrapper_fn( fn: Callable[..., object], parameter_type: Optional[ - Union[computation_types.StructType, computation_types.TensorType] + Union[federated_language.StructType, federated_language.TensorType] ], unpack: Optional[bool], name: Optional[str] = None, **kwargs, -) -> computation_impl.ConcreteComputation: +) -> federated_language.framework.ConcreteComputation: """Serializes a Python function containing JAX code as a TFF computation. Args: fn: The Python function containing JAX code to be serialized as a computation containing XLA. - parameter_type: An instance of `computation_types.Type` that represents the + parameter_type: An instance of `federated_language.Type` that represents the TFF type of the computation parameter, or `None` if there's none. - unpack: See `unpack` in `function_utils.wrap_as_zero_or_one_arg_callable`. + unpack: See `unpack` in + `federated_language.framework.wrap_as_zero_or_one_arg_callable`. name: The name for the constructed computation (currently ignored). **kwargs: Unused currently. A placeholder for passing Jax strategy specific parameters. Returns: - An instance of `computation_impl.ConcreteComputation` with the constructed + An instance of `federated_language.framework.ConcreteComputation` with the + constructed computation. """ del unpack, name, kwargs # Unused. @@ -111,18 +109,20 @@ def _jax_wrapper_fn( f' for more information.\nFound: {parameter_type}' ) - context_stack = context_stack_impl.context_stack + context_stack = federated_language.framework.global_context_stack comp_pb, extra_type_spec = jax_serialization.serialize_jax_computation( fn, parameter_type, context_stack ) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=comp_pb, context_stack=context_stack, annotated_type=extra_type_spec, ) -jax_computation = computation_wrapper.ComputationWrapper(_jax_wrapper_fn) +jax_computation = federated_language.framework.ComputationWrapper( + _jax_wrapper_fn +) jax_computation.__doc__ = """Decorates/wraps Python functions containing JAX code as TFF computations. This wrapper can be used in a similar manner to `tff.tensorflow.computation`, diff --git a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_context.py b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_context.py index 74efba2784..4eae5efa6c 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_context.py +++ b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_context.py @@ -13,10 +13,10 @@ # limitations under the License. """The implementation of an experimental JAX computation context.""" -from tensorflow_federated.python.core.impl.context_stack import context_base +import federated_language -class JaxComputationContext(context_base.SyncContext): +class JaxComputationContext(federated_language.framework.SyncContext): """An experimental context for building JAX computations.""" def __init__(self): diff --git a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py index 5dd31cd356..6dcb5f000e 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py +++ b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py @@ -16,14 +16,12 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import jax import ml_dtypes import numpy as np from tensorflow_federated.python.core.environments.jax_frontend import jax_computation -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import polymorphic_computation -from tensorflow_federated.python.core.impl.types import computation_types class ToNumpyTest(parameterized.TestCase): @@ -54,8 +52,10 @@ def test_returns_concrete_computation_with_no_arg(self): def _comp(): return 1 - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType(None, np.int32) + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType(None, np.int32) self.assertEqual(_comp.type_signature, expected_type) def test_returns_concrete_computation_with_one_arg(self): @@ -63,8 +63,10 @@ def test_returns_concrete_computation_with_one_arg(self): def _comp(x): return jax.numpy.add(x, 1) - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType(np.int32, np.int32) + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType(np.int32, np.int32) self.assertEqual(_comp.type_signature, expected_type) def test_returns_concrete_computation_with_two_args(self): @@ -72,9 +74,11 @@ def test_returns_concrete_computation_with_two_args(self): def _comp(x, y): return jax.numpy.add(x, y) - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType( - computation_types.StructWithPythonType( + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType( + federated_language.StructWithPythonType( [('x', np.int32), ('y', np.int32)], collections.OrderedDict ), np.int32, @@ -82,17 +86,20 @@ def _comp(x, y): self.assertEqual(_comp.type_signature, expected_type) def test_returns_concrete_computation_with_correct_arg_order(self): + @jax_computation.jax_computation( - computation_types.TensorType(np.int32, (10,)), np.int32 + federated_language.TensorType(np.int32, (10,)), np.int32 ) def _comp(y, x): return jax.numpy.add(x, jax.numpy.sum(y)) - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType( - computation_types.StructWithPythonType( + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType( + federated_language.StructWithPythonType( [ - ('y', computation_types.TensorType(np.int32, (10,))), + ('y', federated_language.TensorType(np.int32, (10,))), ('x', np.int32), ], collections.OrderedDict, @@ -102,34 +109,36 @@ def _comp(y, x): self.assertEqual(_comp.type_signature, expected_type) @parameterized.named_parameters( - ('bool', computation_types.TensorType(np.bool_)), - ('int8', computation_types.TensorType(np.int8)), - ('int16', computation_types.TensorType(np.int16)), - ('int32', computation_types.TensorType(np.int32)), - ('uint8', computation_types.TensorType(np.uint8)), - ('uint16', computation_types.TensorType(np.uint16)), - ('uint32', computation_types.TensorType(np.uint32)), - ('float16', computation_types.TensorType(np.float16)), - ('float32', computation_types.TensorType(np.float32)), - ('complex64', computation_types.TensorType(np.complex64)), - ('bfloat16', computation_types.TensorType(ml_dtypes.bfloat16)), - ('generic', computation_types.TensorType(np.int32)), - ('array', computation_types.TensorType(np.int32, shape=[3])), + ('bool', federated_language.TensorType(np.bool_)), + ('int8', federated_language.TensorType(np.int8)), + ('int16', federated_language.TensorType(np.int16)), + ('int32', federated_language.TensorType(np.int32)), + ('uint8', federated_language.TensorType(np.uint8)), + ('uint16', federated_language.TensorType(np.uint16)), + ('uint32', federated_language.TensorType(np.uint32)), + ('float16', federated_language.TensorType(np.float16)), + ('float32', federated_language.TensorType(np.float32)), + ('complex64', federated_language.TensorType(np.complex64)), + ('bfloat16', federated_language.TensorType(ml_dtypes.bfloat16)), + ('generic', federated_language.TensorType(np.int32)), + ('array', federated_language.TensorType(np.int32, shape=[3])), ) def test_returns_concrete_computation_with_dtype(self, type_spec): @jax_computation.jax_computation(type_spec) def _comp(x): return x - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType(type_spec, type_spec) + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType(type_spec, type_spec) self.assertEqual(_comp.type_signature, expected_type) @parameterized.named_parameters( - ('int64', computation_types.TensorType(np.int64)), - ('uint64', computation_types.TensorType(np.uint64)), - ('float64', computation_types.TensorType(np.float64)), - ('complex128', computation_types.TensorType(np.complex128)), + ('int64', federated_language.TensorType(np.int64)), + ('uint64', federated_language.TensorType(np.uint64)), + ('float64', federated_language.TensorType(np.float64)), + ('complex128', federated_language.TensorType(np.complex128)), ) def test_returns_concrete_computation_with_dtype_and_enable_x64( self, type_spec @@ -140,17 +149,19 @@ def test_returns_concrete_computation_with_dtype_and_enable_x64( def _comp(x): return x - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType(type_spec, type_spec) + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType(type_spec, type_spec) self.assertEqual(_comp.type_signature, expected_type) jax.config.update('jax_enable_x64', False) @parameterized.named_parameters( - ('int64', computation_types.TensorType(np.int64)), - ('uint64', computation_types.TensorType(np.uint64)), - ('float64', computation_types.TensorType(np.float64)), - ('complex128', computation_types.TensorType(np.complex128)), - ('str', computation_types.TensorType(np.str_)), + ('int64', federated_language.TensorType(np.int64)), + ('uint64', federated_language.TensorType(np.uint64)), + ('float64', federated_language.TensorType(np.float64)), + ('complex128', federated_language.TensorType(np.complex128)), + ('str', federated_language.TensorType(np.str_)), ) async def test_raises_raises_value_error_with_dtype(self, type_spec): with self.assertRaises(ValueError): @@ -164,7 +175,9 @@ def test_returns_polymorphic_computation(self): def _comp(x): return jax.numpy.add(x, 1) - self.assertIsInstance(_comp, polymorphic_computation.PolymorphicComputation) + self.assertIsInstance( + _comp, federated_language.framework.PolymorphicComputation + ) if __name__ == '__main__': diff --git a/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization.py b/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization.py index eb370219ec..803dd58cac 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization.py +++ b/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization.py @@ -17,35 +17,32 @@ from collections.abc import Callable, Sequence from typing import Optional, Union +import federated_language +from federated_language.proto import computation_pb2 as pb import jax import numpy as np -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.jax_frontend import jax_computation_context from tensorflow_federated.python.core.environments.xla_backend import xla_serialization -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import typed_object -class _XlaSerializerTensorArg(jax.ShapeDtypeStruct, typed_object.TypedObject): +class _XlaSerializerTensorArg( + jax.ShapeDtypeStruct, federated_language.TypedObject +): """Represents tensor type info understood by both TFF and JAX serializer.""" def __init__( - self, tensor_type: computation_types.TensorType, tensor_index: int + self, tensor_type: federated_language.TensorType, tensor_index: int ): - py_typecheck.check_type(tensor_type, computation_types.TensorType) + py_typecheck.check_type(tensor_type, federated_language.TensorType) jax.ShapeDtypeStruct.__init__(self, tensor_type.shape, tensor_type.dtype) self._type_signature = tensor_type self._tensor_index = tensor_index @property - def type_signature(self) -> computation_types.TensorType: + def type_signature(self) -> federated_language.TensorType: return self._type_signature @property @@ -54,20 +51,20 @@ def tensor_index(self) -> int: @jax.tree_util.register_pytree_node_class -class _XlaSerializerStructArg(structure.Struct, typed_object.TypedObject): +class _XlaSerializerStructArg(structure.Struct, federated_language.TypedObject): """Represents struct type info understood by both TFF and JAX serializer.""" def __init__( self, - type_spec: computation_types.StructType, + type_spec: federated_language.StructType, elements: Sequence[tuple[Optional[str], object]], ): - py_typecheck.check_type(type_spec, computation_types.StructType) + py_typecheck.check_type(type_spec, federated_language.StructType) structure.Struct.__init__(self, elements) self._type_signature = type_spec @property - def type_signature(self) -> computation_types.StructType: + def type_signature(self) -> federated_language.StructType: return self._type_signature def __str__(self) -> str: @@ -77,14 +74,14 @@ def tree_flatten( self, ) -> tuple[ tuple[Union[_XlaSerializerTensorArg, '_XlaSerializerStructArg'], ...], - computation_types.StructType, + federated_language.StructType, ]: return tuple(self), self._type_signature @classmethod def tree_unflatten( cls, - aux_data: computation_types.StructType, + aux_data: federated_language.StructType, children: tuple[ Union[_XlaSerializerTensorArg, '_XlaSerializerStructArg'], ... ], @@ -96,25 +93,27 @@ def tree_unflatten( def _tff_type_to_xla_serializer_arg( - type_spec: computation_types.Type, + type_spec: federated_language.Type, ) -> Union[_XlaSerializerStructArg, _XlaSerializerTensorArg]: """Converts TFF type into an argument for the JAX-to-XLA serializer. Args: - type_spec: An instance of `computation_types.Type` containing only structure - and tensor elements. + type_spec: An instance of `federated_language.Type` containing only + structure and tensor elements. Returns: An object that carries both TFF and JAX type info, to be fed into the JAX serializer. """ - def _undefined_shape_predicate(type_element: computation_types.Type) -> bool: - if not isinstance(type_element, computation_types.TensorType): + def _undefined_shape_predicate(type_element: federated_language.Type) -> bool: + if not isinstance(type_element, federated_language.TensorType): return False - return not array_shape.is_shape_fully_defined(type_element.shape) + return not federated_language.array_shape_is_fully_defined( + type_element.shape + ) - has_undefined_shapes = type_analysis.contains( + has_undefined_shapes = federated_language.framework.type_contains( type_spec, _undefined_shape_predicate ) if has_undefined_shapes: @@ -127,13 +126,13 @@ def _undefined_shape_predicate(type_element: computation_types.Type) -> bool: ) def _make( - type_spec: computation_types.Type, next_unused_tensor_index: int + type_spec: federated_language.Type, next_unused_tensor_index: int ) -> tuple[Union[_XlaSerializerStructArg, _XlaSerializerTensorArg], int]: - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): obj = _XlaSerializerTensorArg(type_spec, next_unused_tensor_index) next_unused_tensor_index = next_unused_tensor_index + 1 return obj, next_unused_tensor_index - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = [] for k, v in structure.to_elements(type_spec): obj, next_unused_tensor_index = _make(v, next_unused_tensor_index) @@ -153,55 +152,59 @@ def _make( def _jax_shape_dtype_struct_to_tff_tensor( val: jax.ShapeDtypeStruct, -) -> computation_types.TensorType: - """Converts `jax.ShapeDtypeStruct` to `computation_types.TensorType`. +) -> federated_language.TensorType: + """Converts `jax.ShapeDtypeStruct` to `federated_language.TensorType`. Args: val: An instance of `jax.ShapeDtypeStruct`. Returns: - A corresponding instance of `computation_types.TensorType`. + A corresponding instance of `federated_language.TensorType`. Raises: TypeError: if arg type mismatches. """ - return computation_types.TensorType(val.dtype, val.shape) + return federated_language.TensorType(val.dtype, val.shape) def serialize_jax_computation( fn: Callable[..., object], parameter_type: Union[ - computation_types.StructType, computation_types.TensorType + federated_language.StructType, federated_language.TensorType ], - context_stack: context_stack_base.ContextStack, -) -> tuple[pb.Computation, computation_types.FunctionType]: + context_stack: federated_language.framework.ContextStack, +) -> tuple[pb.Computation, federated_language.FunctionType]: """Serializes a Python function containing JAX code as a TFF computation. Args: fn: The Python function containing JAX code to be traced by JAX and serialized as a TFF computation containing XLA code. - parameter_type: An instance of `computation_types.Type` that represents the + parameter_type: An instance of `federated_language.Type` that represents the TFF type of the computation parameter, or `None` if the function does not take any parameters. context_stack: The context stack to use during serialization. Returns: A 2-tuple of `pb.Computation` with the constructed computation and a - `computation_types.FunctionType` containing the full type including + `federated_language.FunctionType` containing the full type including Python container annotations. Raises: TypeError: if the arguments are of the wrong types. """ - py_typecheck.check_type(context_stack, context_stack_base.ContextStack) + py_typecheck.check_type( + context_stack, federated_language.framework.ContextStack + ) if parameter_type is not None: - parameter_type = computation_types.to_type(parameter_type) + parameter_type = federated_language.to_type(parameter_type) packed_arg = _tff_type_to_xla_serializer_arg(parameter_type) else: packed_arg = None - args, kwargs = function_utils.unpack_arg(fn, parameter_type, packed_arg) + args, kwargs = federated_language.framework.unpack_arg( + fn, parameter_type, packed_arg + ) # While the fake parameters are fed via args/kwargs during serialization, # it is possible for them to get reordered in the actual generated XLA code. @@ -229,7 +232,7 @@ def serialize_jax_computation( ) ) else: - returned_type_spec = computation_types.to_type( + returned_type_spec = federated_language.to_type( jax.tree_util.tree_map( _jax_shape_dtype_struct_to_tff_tensor, lowered.out_info ) @@ -242,13 +245,13 @@ def serialize_jax_computation( if isinstance(returned_shape, jax.ShapeDtypeStruct): returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape) else: - returned_type_spec = computation_types.to_type( + returned_type_spec = federated_language.to_type( jax.tree_util.tree_map( _jax_shape_dtype_struct_to_tff_tensor, returned_shape ) ) - computation_type = computation_types.FunctionType( + computation_type = federated_language.FunctionType( parameter_type, returned_type_spec ) return ( diff --git a/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization_test.py b/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization_test.py index f0e4925978..3cbc85c7d5 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization_test.py +++ b/tensorflow_federated/python/core/environments/jax_frontend/jax_serialization_test.py @@ -15,17 +15,14 @@ import collections from absl.testing import absltest +import federated_language +from federated_language.proto import computation_pb2 as pb import jax import numpy as np -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.jax_frontend import jax_serialization from tensorflow_federated.python.core.environments.xla_backend import xla_serialization -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import type_test_utils class JaxSerializationTest(absltest.TestCase): @@ -36,10 +33,10 @@ def test_serialize_jax_computation_raises_type_error_with_unknown_rank( def fn(x): return x + 10 - parameter_type = computation_types.TensorType(dtype=np.int32, shape=None) + parameter_type = federated_language.TensorType(dtype=np.int32, shape=None) with self.assertRaisesRegex(TypeError, 'fully-defined TensorShapes'): jax_serialization.serialize_jax_computation( - fn, parameter_type, context_stack_impl.context_stack + fn, parameter_type, federated_language.framework.global_context_stack ) def test_serialize_jax_computation_raises_type_error_with_unknown_dimension( @@ -48,10 +45,10 @@ def test_serialize_jax_computation_raises_type_error_with_unknown_dimension( def fn(x): return x + 10 - parameter_type = computation_types.TensorType(dtype=np.int32, shape=[None]) + parameter_type = federated_language.TensorType(dtype=np.int32, shape=[None]) with self.assertRaisesRegex(TypeError, 'fully-defined TensorShapes'): jax_serialization.serialize_jax_computation( - fn, parameter_type, context_stack_impl.context_stack + fn, parameter_type, federated_language.framework.global_context_stack ) def test_serialize_jax_with_noarg_to_int32(self): @@ -60,20 +57,22 @@ def traced_fn(): parameter_type = None comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) @@ -87,22 +86,24 @@ def traced_fn(x): del x return 10 - parameter_type = computation_types.to_type(np.int32) + parameter_type = federated_language.to_type(np.int32) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) @@ -115,22 +116,24 @@ def test_serialize_jax_with_int32_to_int32(self): def traced_fn(x): return x + 10 - parameter_type = computation_types.to_type(np.int32) + parameter_type = federated_language.to_type(np.int32) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) @@ -145,29 +148,31 @@ def traced_fn(x): sum=x['foo'] + x['bar'], difference=x['bar'] - x['foo'] ) - parameter_type = computation_types.to_type( + parameter_type = federated_language.to_type( collections.OrderedDict(foo=np.int32, bar=np.int32) ) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, - result=computation_types.StructType( + result=federated_language.StructType( [('sum', np.int32), ('difference', np.int32)] ), ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, - result=computation_types.StructWithPythonType( + result=federated_language.StructWithPythonType( [('sum', np.int32), ('difference', np.int32)], container_type=collections.OrderedDict, ), @@ -198,24 +203,26 @@ def test_serialize_jax_with_two_args(self): def traced_fn(x, y): return x + y - parameter_type = computation_types.to_type( + parameter_type = federated_language.to_type( collections.OrderedDict(x=np.int32, y=np.int32) ) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) @@ -244,25 +251,27 @@ def test_serialize_jax_with_nested_struct_arg(self): def traced_fn(x, y): return x[0] + y - parameter_type = computation_types.StructType([ - (None, computation_types.StructType([(None, np.int32)])), + parameter_type = federated_language.StructType([ + (None, federated_language.StructType([(None, np.int32)])), (None, np.int32), ]) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) @@ -271,71 +280,81 @@ def test_nested_structure_type_signature_roundtrip(self): def traced_fn(x): return x[0][0] - parameter_type = computation_types.to_type([(np.int32,)]) + parameter_type = federated_language.to_type([(np.int32,)]) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) def test_arg_ordering(self): - parameter_type = computation_types.to_type(( - computation_types.TensorType(np.int32, (10,)), - computation_types.TensorType(np.int32), + parameter_type = federated_language.to_type(( + federated_language.TensorType(np.int32, (10,)), + federated_language.TensorType(np.int32), )) def traced_fn(b, a): return jax.numpy.add(a, jax.numpy.sum(b)) comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - traced_fn, parameter_type, context_stack_impl.context_stack + traced_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType( + federated_language.FunctionType( parameter=parameter_type, result=np.int32 ), ) def test_tracing_with_float64_input(self): self.skipTest('b/237566862') - parameter_type = computation_types.TensorType(np.float64) + parameter_type = federated_language.TensorType(np.float64) identity_fn = lambda x: x comp_pb, annotated_type = jax_serialization.serialize_jax_computation( - identity_fn, parameter_type, context_stack_impl.context_stack + identity_fn, + parameter_type, + federated_language.framework.global_context_stack, ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) - type_test_utils.assert_types_equivalent( + type_spec = federated_language.framework.deserialize_type(comp_pb.type) + federated_language.framework.assert_types_equivalent( type_spec, - computation_types.FunctionType(parameter=np.float64, result=np.float64), + federated_language.FunctionType( + parameter=np.float64, result=np.float64 + ), ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( annotated_type, - computation_types.FunctionType(parameter=np.float64, result=np.float64), + federated_language.FunctionType( + parameter=np.float64, result=np.float64 + ), ) @@ -386,7 +405,7 @@ class XlaSerializerStructArgPytreeTest(absltest.TestCase): def test_named_struct(self): struct_arg = jax_serialization._XlaSerializerStructArg( - computation_types.StructType( + federated_language.StructType( [('a', np.float32), ('b', np.int64), ('c', np.int32)] ), [('a', 1.0), ('b', 2), ('c', 3)], @@ -398,7 +417,7 @@ def test_named_struct(self): def test_unnamed_struct(self): struct_arg = jax_serialization._XlaSerializerStructArg( - computation_types.StructType([ + federated_language.StructType([ (None, np.float32), (None, np.int64), (None, np.int32), @@ -412,7 +431,7 @@ def test_unnamed_struct(self): def test_mixed_named_struct(self): struct_arg = jax_serialization._XlaSerializerStructArg( - computation_types.StructType([ + federated_language.StructType([ ('a', np.int32), (None, np.int32), ('b', np.int64), @@ -427,15 +446,15 @@ def test_mixed_named_struct(self): def test_nested_structs(self): struct_arg = jax_serialization._XlaSerializerStructArg( - computation_types.StructType([ + federated_language.StructType([ ('a', np.int32), ( 'b', - computation_types.StructType([ + federated_language.StructType([ ('c', np.int32), ( 'd', - computation_types.StructType( + federated_language.StructType( [(None, np.int32), (None, np.int32)] ), ), @@ -448,11 +467,11 @@ def test_nested_structs(self): ( 'b', jax_serialization._XlaSerializerStructArg( - computation_types.StructType([ + federated_language.StructType([ ('c', np.int32), ( 'd', - computation_types.StructType( + federated_language.StructType( [(None, np.int32), (None, np.int32)] ), ), @@ -463,7 +482,7 @@ def test_nested_structs(self): ( 'd', jax_serialization._XlaSerializerStructArg( - computation_types.StructType( + federated_language.StructType( [(None, np.int32), (None, np.int32)] ), elements=[(None, 5), (None, 6)], @@ -482,13 +501,13 @@ def test_nested_structs(self): def test_mixed_nested_structs_and_python_containers(self): struct_arg = jax_serialization._XlaSerializerStructArg( - computation_types.StructType([ + federated_language.StructType([ ('a', np.int32), - computation_types.StructType([( + federated_language.StructType([( (None, np.int32), ( None, - computation_types.StructType( + federated_language.StructType( [(None, np.int32), (None, np.int32)] ), ), @@ -502,7 +521,7 @@ def test_mixed_nested_structs_and_python_containers(self): [ 4, jax_serialization._XlaSerializerStructArg( - computation_types.StructType( + federated_language.StructType( [(None, np.int32), (None, np.int32)] ), elements=[(None, 5), (None, 6)], diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD b/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD index 574127dcd0..827f97d699 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD @@ -33,10 +33,9 @@ py_library( ":serialization_utils", ":tensorflow_computation_transformations", ":tensorflow_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -49,8 +48,7 @@ py_test( ":tensorflow_computation_factory", ":tensorflow_computation_test_utils", ":tensorflow_computation_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -118,13 +116,7 @@ py_library( ":type_conversions", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", + "@federated_language//federated_language", ], ) @@ -134,10 +126,7 @@ py_test( deps = [ ":tensorflow_building_block_factory", ":tensorflow_computation_test_utils", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -148,14 +137,10 @@ py_library( ":serialization_utils", ":tensorflow_utils", ":type_conversions", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:type_transformations", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -165,11 +150,9 @@ py_cpu_gpu_test( deps = [ ":tensorflow_computation_factory", ":tensorflow_computation_test_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -179,10 +162,9 @@ py_library( srcs = ["tensorflow_computation_test_utils.py"], deps = [ ":tensorflow_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -191,8 +173,8 @@ py_library( srcs = ["tensorflow_computation_transformations.py"], deps = [ ":serialization_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -205,10 +187,8 @@ py_test( ":tensorflow_computation_factory", ":tensorflow_computation_transformations", ":tensorflow_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -234,9 +214,7 @@ py_test( "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", "//tensorflow_federated/python/core/impl/executors:executor_bindings", "//tensorflow_federated/python/core/impl/executors:value_serialization", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -254,12 +232,7 @@ py_library( ":tensorflow_computation_factory", ":type_conversions", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", + "@federated_language//federated_language", ], ) @@ -268,12 +241,7 @@ py_test( srcs = ["tensorflow_tree_transformations_test.py"], deps = [ ":tensorflow_tree_transformations", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -284,13 +252,11 @@ py_library( ":graph_utils", ":serialization_utils", ":type_conversions", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -302,12 +268,10 @@ py_test( ":serialization_utils", ":tensorflow_test_utils", ":tensorflow_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -318,9 +282,7 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:typed_object", + "@federated_language//federated_language", ], ) @@ -330,9 +292,6 @@ py_test( srcs = ["type_conversions_test.py"], deps = [ ":type_conversions", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:type_test_utils", - "//tensorflow_federated/python/core/impl/types:typed_object", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations.py b/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations.py index bb5a944c57..0b325e2cdd 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations.py @@ -15,15 +15,15 @@ import ctypes -from tensorflow_federated.proto.v0 import computation_pb2 +import federated_language +from federated_language.proto import computation_pb2 + from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import graph_optimizations from tensorflow_federated.python.core.environments.tensorflow_backend import graph_spec from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_transformations from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import transformation_utils def _unpack_proto_into_graph_spec(tf_block_proto): @@ -64,8 +64,8 @@ def optimize_tensorflow_comp(tf_computation, config_proto): """Applies configured optimizations to the graphdef backing a TF comp. Args: - tf_computation: Instance of `building_blocks.CompiledComputation` backed by - TensorFlow. + tf_computation: Instance of + `federated_language.framework.CompiledComputation` backed by TensorFlow. config_proto: Instance of `tf.compat.v1.ConfigProto` specifying the optimizations to apply to the graph backing this TensorFlow computation. @@ -74,7 +74,9 @@ def optimize_tensorflow_comp(tf_computation, config_proto): `tf.compat.v1.GraphDef` backing it run through Grappler with the specified configuration. """ - py_typecheck.check_type(tf_computation, building_blocks.CompiledComputation) + py_typecheck.check_type( + tf_computation, federated_language.framework.CompiledComputation + ) tf_proto = tf_computation.proto graph_spec_obj = _unpack_proto_into_graph_spec(tf_proto) @@ -101,15 +103,16 @@ def optimize_tensorflow_comp(tf_computation, config_proto): optimized_proto = computation_pb2.Computation( type=tf_proto.type, tensorflow=tf_result_proto ) - return building_blocks.CompiledComputation( + return federated_language.framework.CompiledComputation( optimized_proto, type_signature=tf_computation.type_signature ) -class TensorFlowOptimizer(transformation_utils.TransformSpec): - """Applies TF graph optimizations to `building_blocks.CompiledComputation`s. +class TensorFlowOptimizer(federated_language.framework.TransformSpec): + """Applies TF graph optimizations to `federated_language.framework.CompiledComputation`s. - This `transformation_utils.TransformSpec` does not alter the TFF structure of + This `federated_language.framework.TransformSpec` does not alter the TFF + structure of the computations on which it is called; rather, it calls out to TensorFlow libraries which perform optimization on the underlying TensorFlow graph representing local processing. @@ -119,7 +122,7 @@ def __init__(self, config_proto): self._config_proto = config_proto def should_transform(self, comp): - return isinstance(comp, building_blocks.CompiledComputation) + return isinstance(comp, federated_language.framework.CompiledComputation) def transform(self, comp): if not self.should_transform(comp): @@ -130,36 +133,39 @@ def transform(self, comp): def optimize_tensorflow_graphs(comp, grappler_config_proto): """Performs any static optimization on TensorFlow subcomputations.""" transform_spec = TensorFlowOptimizer(grappler_config_proto) - return transformation_utils.transform_postorder( + return federated_language.framework.transform_postorder( comp, transform_spec.transform ) -class DisableCallOpGrappler(transformation_utils.TransformSpec): - """Disables grappler in Call ops in `building_blocks.CompiledComputation`s. +class DisableCallOpGrappler(federated_language.framework.TransformSpec): + """Disables grappler in Call ops in `federated_language.framework.CompiledComputation`s. This overwrites the `config_proto` key of the `NodeDef.attr` field of nodes in a `tf.compat.v1.GraphDef` to ensure that Grappler is disabled at runtime. - This `transformation_utils.TransformSpec` does not alter the TFF structure of + This `federated_language.framework.TransformSpec` does not alter the TFF + structure of the computations on which it is called. """ def should_transform(self, comp): return ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'tensorflow' ) def transform(self, comp): if not self.should_transform(comp): return comp, False - py_typecheck.check_type(comp, building_blocks.CompiledComputation) + py_typecheck.check_type( + comp, federated_language.framework.CompiledComputation + ) new_comp_proto = tensorflow_computation_transformations.disable_grappler_for_partitioned_calls( comp.proto ) return ( - building_blocks.CompiledComputation( + federated_language.framework.CompiledComputation( new_comp_proto, type_signature=comp.type_signature ), True, @@ -169,12 +175,12 @@ def transform(self, comp): def transform_tf_call_ops_to_disable_grappler(comp): """Performs grappler disabling on TensorFlow subcomputations.""" transform_spec = DisableCallOpGrappler() - return transformation_utils.transform_postorder( + return federated_language.framework.transform_postorder( comp, transform_spec.transform ) -class VerifyAllowedOps(transformation_utils.TransformSpec): +class VerifyAllowedOps(federated_language.framework.TransformSpec): """Identity transformation that verifies computation contains only allowed ops. This tranverses Tensorflow compiled computations and checks each op is @@ -182,7 +188,8 @@ class VerifyAllowedOps(transformation_utils.TransformSpec): `DisallowedOpInTensorFlowComputationError`. Otherwise if only allowed ops are found, the original computation is returned. - This `transformation_utils.TransformSpec` does not alter the TFF structure of + This `federated_language.framework.TransformSpec` does not alter the TFF + structure of the computations on which it is called. """ @@ -190,19 +197,21 @@ def __init__(self, allowed_op_names: frozenset[str]): self._allowed_op_names = allowed_op_names def should_transform( - self, comp: building_blocks.ComputationBuildingBlock + self, comp: federated_language.framework.ComputationBuildingBlock ) -> bool: return ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'tensorflow' ) def transform( - self, comp: building_blocks.ComputationBuildingBlock - ) -> tuple[building_blocks.ComputationBuildingBlock, bool]: + self, comp: federated_language.framework.ComputationBuildingBlock + ) -> tuple[federated_language.framework.ComputationBuildingBlock, bool]: if not self.should_transform(comp): return comp, False - py_typecheck.check_type(comp, building_blocks.CompiledComputation) + py_typecheck.check_type( + comp, federated_language.framework.CompiledComputation + ) tensorflow_computation_transformations.check_allowed_ops( comp.proto, self._allowed_op_names ) @@ -210,17 +219,17 @@ def transform( def check_allowed_ops( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, allowed_op_names: frozenset[str], -) -> tuple[building_blocks.ComputationBuildingBlock, bool]: +) -> tuple[federated_language.framework.ComputationBuildingBlock, bool]: """Checks any Tensorflow computation contains only allowed ops.""" transform_spec = VerifyAllowedOps(allowed_op_names) - return transformation_utils.transform_postorder( + return federated_language.framework.transform_postorder( comp, transform_spec.transform ) -class RaiseOnDisallowedOp(transformation_utils.TransformSpec): +class RaiseOnDisallowedOp(federated_language.framework.TransformSpec): """Identity transformation that raises an error if a disallowed op is found. This tranverses Tensorflow compiled computations searching for ops that have @@ -228,7 +237,8 @@ class RaiseOnDisallowedOp(transformation_utils.TransformSpec): `DisallowedOpInTensorFlowComputationError`. Otherwise if no disallowed ops are found, the original computation is returned. - This `transformation_utils.TransformSpec` does not alter the TFF structure of + This `federated_language.framework.TransformSpec` does not alter the TFF + structure of the computations on which it is called. """ @@ -236,19 +246,21 @@ def __init__(self, disallowed_op_names: frozenset[str]): self._disallowed_op_names = disallowed_op_names def should_transform( - self, comp: building_blocks.ComputationBuildingBlock + self, comp: federated_language.framework.ComputationBuildingBlock ) -> bool: return ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'tensorflow' ) def transform( - self, comp: building_blocks.ComputationBuildingBlock - ) -> tuple[building_blocks.ComputationBuildingBlock, bool]: + self, comp: federated_language.framework.ComputationBuildingBlock + ) -> tuple[federated_language.framework.ComputationBuildingBlock, bool]: if not self.should_transform(comp): return comp, False - py_typecheck.check_type(comp, building_blocks.CompiledComputation) + py_typecheck.check_type( + comp, federated_language.framework.CompiledComputation + ) tensorflow_computation_transformations.check_no_disallowed_ops( comp.proto, self._disallowed_op_names ) @@ -256,37 +268,40 @@ def transform( def check_disallowed_ops( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, disallowed_op_names: frozenset[str], -) -> tuple[building_blocks.ComputationBuildingBlock, bool]: +) -> tuple[federated_language.framework.ComputationBuildingBlock, bool]: """Raises error on disallowed ops in any Tensorflow computation.""" transform_spec = RaiseOnDisallowedOp(disallowed_op_names) - return transformation_utils.transform_postorder( + return federated_language.framework.transform_postorder( comp, transform_spec.transform ) -class AddUniqueIDs(transformation_utils.TransformSpec): +class AddUniqueIDs(federated_language.framework.TransformSpec): """Populates unique IDs for compiled computations. This overwrites the `tensorlfow.id` field (and in the future other compiled computations) with a unique ID. The IDs produced should be determinstic and reproducible when the transform is applied to the same computation. - This `transformation_utils.TransformSpec` does not alter the TFF structure of + This `federated_language.framework.TransformSpec` does not alter the TFF + structure of the computations on which it is called. """ def should_transform(self, comp): return ( - isinstance(comp, building_blocks.CompiledComputation) + isinstance(comp, federated_language.framework.CompiledComputation) and comp.proto.WhichOneof('computation') == 'tensorflow' ) def transform(self, comp): if not self.should_transform(comp): return comp, False - py_typecheck.check_type(comp, building_blocks.CompiledComputation) + py_typecheck.check_type( + comp, federated_language.framework.CompiledComputation + ) new_tf_proto = computation_pb2.TensorFlow() new_tf_proto.CopyFrom(comp.proto.tensorflow) # Important: we must also serialize the type_signature because TFF might @@ -302,7 +317,7 @@ def transform(self, comp): type=comp.proto.type, tensorflow=new_tf_proto ) return ( - building_blocks.CompiledComputation( + federated_language.framework.CompiledComputation( new_comp_proto, type_signature=comp.type_signature ), True, @@ -312,6 +327,6 @@ def transform(self, comp): def transform_tf_add_ids(comp): """Adds unique IDs to each TensorFlow subcomputations.""" transform_spec = AddUniqueIDs() - return transformation_utils.transform_postorder( + return federated_language.framework.transform_postorder( comp, transform_spec.transform ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations_test.py index 7b8c5bb9f2..af689b512b 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/compiled_computation_transformations_test.py @@ -13,6 +13,7 @@ # limitations under the License. from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf @@ -20,8 +21,6 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_test_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_transformations -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.types import computation_types def _create_compiled_computation(py_fn, parameter_type): @@ -30,7 +29,7 @@ def _create_compiled_computation(py_fn, parameter_type): py_fn, parameter_type ) ) - return building_blocks.CompiledComputation( + return federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) @@ -38,11 +37,11 @@ def _create_compiled_computation(py_fn, parameter_type): class TensorFlowOptimizerTest(absltest.TestCase): def test_should_transform_compiled_computation(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) config = tf.compat.v1.ConfigProto() @@ -52,7 +51,7 @@ def test_should_transform_compiled_computation(self): self.assertTrue(tf_optimizer.should_transform(compiled_computation)) def test_should_not_transform_reference(self): - reference = building_blocks.Reference('x', np.int32) + reference = federated_language.framework.Reference('x', np.int32) config = tf.compat.v1.ConfigProto() tf_optimizer = compiled_computation_transformations.TensorFlowOptimizer( config @@ -60,11 +59,11 @@ def test_should_not_transform_reference(self): self.assertFalse(tf_optimizer.should_transform(reference)) def test_transform_compiled_computation_returns_compiled_computation(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) proto, function_type = tensorflow_computation_factory.create_identity( tuple_type, ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( proto, name=None, type_signature=function_type ) @@ -74,7 +73,9 @@ def test_transform_compiled_computation_returns_compiled_computation(self): ) transformed_comp, mutated = tf_optimizer.transform(compiled_computation) self.assertTrue(mutated) - self.assertIsInstance(transformed_comp, building_blocks.CompiledComputation) + self.assertIsInstance( + transformed_comp, federated_language.framework.CompiledComputation + ) self.assertTrue(transformed_comp.proto.tensorflow.HasField('parameter')) self.assertFalse(transformed_comp.proto.tensorflow.initialize_op) @@ -82,7 +83,7 @@ def test_transform_compiled_computation_returns_compiled_computation_without_emp self, ): proto, type_signature = tensorflow_computation_factory.create_empty_tuple() - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) config = tf.compat.v1.ConfigProto() @@ -91,16 +92,18 @@ def test_transform_compiled_computation_returns_compiled_computation_without_emp ) transformed_comp, mutated = tf_optimizer.transform(compiled_computation) self.assertTrue(mutated) - self.assertIsInstance(transformed_comp, building_blocks.CompiledComputation) + self.assertIsInstance( + transformed_comp, federated_language.framework.CompiledComputation + ) self.assertFalse(transformed_comp.proto.tensorflow.HasField('parameter')) self.assertFalse(transformed_comp.proto.tensorflow.initialize_op) def test_transform_compiled_computation_semantic_equivalence(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) config = tf.compat.v1.ConfigProto() @@ -109,7 +112,9 @@ def test_transform_compiled_computation_semantic_equivalence(self): ) transformed_comp, mutated = tf_optimizer.transform(compiled_computation) self.assertTrue(mutated) - self.assertIsInstance(transformed_comp, building_blocks.CompiledComputation) + self.assertIsInstance( + transformed_comp, federated_language.framework.CompiledComputation + ) zero_before_transform = tensorflow_computation_test_utils.run_tensorflow( compiled_computation.proto, 0 ) @@ -122,11 +127,11 @@ def test_transform_compiled_computation_semantic_equivalence(self): class AddUniqueIDsTest(absltest.TestCase): def test_should_transform_compiled_tf_computation(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) self.assertTrue( @@ -136,7 +141,7 @@ def test_should_transform_compiled_tf_computation(self): ) def test_should_not_transform_non_compiled_computations(self): - reference = building_blocks.Reference('x', np.int32) + reference = federated_language.framework.Reference('x', np.int32) self.assertFalse( compiled_computation_transformations.AddUniqueIDs().should_transform( reference @@ -149,17 +154,17 @@ def test_transform_same_compiled_computation_different_type_signature(self): # should produce a different ID. proto, type_signature = ( tensorflow_computation_factory.create_unary_operator( - lambda x: (), operand_type=computation_types.StructType([]) + lambda x: (), operand_type=federated_language.StructType([]) ) ) - empty_tuple_computation = building_blocks.CompiledComputation( + empty_tuple_computation = federated_language.framework.CompiledComputation( proto, type_signature=type_signature ) add_ids = compiled_computation_transformations.AddUniqueIDs() first_transformed_comp, mutated = add_ids.transform(empty_tuple_computation) self.assertTrue(mutated) self.assertIsInstance( - first_transformed_comp, building_blocks.CompiledComputation + first_transformed_comp, federated_language.framework.CompiledComputation ) self.assertTrue( first_transformed_comp.proto.tensorflow.HasField('cache_key') @@ -169,18 +174,21 @@ def test_transform_same_compiled_computation_different_type_signature(self): # type_signature. proto, type_signature = ( tensorflow_computation_factory.create_unary_operator( - lambda x: ((),), operand_type=computation_types.StructType([]) + lambda x: ((),), operand_type=federated_language.StructType([]) ) ) - nested_empty_tuple_computation = building_blocks.CompiledComputation( - proto, type_signature=type_signature + nested_empty_tuple_computation = ( + federated_language.framework.CompiledComputation( + proto, type_signature=type_signature + ) ) second_transformed_comp, mutated = add_ids.transform( nested_empty_tuple_computation ) self.assertTrue(mutated) self.assertIsInstance( - second_transformed_comp, building_blocks.CompiledComputation + second_transformed_comp, + federated_language.framework.CompiledComputation, ) self.assertTrue( second_transformed_comp.proto.tensorflow.HasField('cache_key') @@ -197,11 +205,11 @@ def test_transform_same_compiled_computation_different_type_signature(self): def test_transform_compiled_computation_returns_compiled_computation_with_id( self, ): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) add_ids = compiled_computation_transformations.AddUniqueIDs() @@ -209,7 +217,8 @@ def test_transform_compiled_computation_returns_compiled_computation_with_id( first_transformed_comp, mutated = add_ids.transform(compiled_computation) self.assertTrue(mutated) self.assertIsInstance( - first_transformed_comp, building_blocks.CompiledComputation + first_transformed_comp, + federated_language.framework.CompiledComputation, ) self.assertTrue( first_transformed_comp.proto.tensorflow.HasField('cache_key') @@ -221,7 +230,8 @@ def test_transform_compiled_computation_returns_compiled_computation_with_id( second_transformed_comp, mutated = add_ids.transform(compiled_computation) self.assertTrue(mutated) self.assertIsInstance( - second_transformed_comp, building_blocks.CompiledComputation + second_transformed_comp, + federated_language.framework.CompiledComputation, ) self.assertTrue( second_transformed_comp.proto.tensorflow.HasField('cache_key') @@ -253,7 +263,7 @@ def test_transform_compiled_computation_returns_compiled_computation_with_id( with self.subTest('different_computation_different_id'): different_compiled_computation = _create_compiled_computation( lambda x: x + np.float32(1.0), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), ) different_transformed_comp, mutated = add_ids.transform( different_compiled_computation @@ -274,11 +284,11 @@ def test_transform_compiled_computation_returns_compiled_computation_with_id( class VerifyAllowedOpsTest(absltest.TestCase): def test_should_transform_tf_computation(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) self.assertTrue( @@ -288,7 +298,7 @@ def test_should_transform_tf_computation(self): ) def test_should_not_transform_non_compiled_computations(self): - reference = building_blocks.Reference('x', np.int32) + reference = federated_language.framework.Reference('x', np.int32) self.assertFalse( compiled_computation_transformations.VerifyAllowedOps( frozenset() @@ -296,11 +306,11 @@ def test_should_not_transform_non_compiled_computations(self): ) def test_transform_only_allowed_ops(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) allowed_op_names = frozenset( @@ -312,11 +322,11 @@ def test_transform_only_allowed_ops(self): self.assertFalse(mutated) def test_transform_disallowed_ops(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) allowed_op_names = frozenset(['Identity']) @@ -331,11 +341,11 @@ def test_transform_disallowed_ops(self): class RaiseOnDisallowedOpTest(absltest.TestCase): def test_should_transform_tf_computation(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) self.assertTrue( @@ -345,7 +355,7 @@ def test_should_transform_tf_computation(self): ) def test_should_not_transform_non_compiled_computations(self): - reference = building_blocks.Reference('x', np.int32) + reference = federated_language.framework.Reference('x', np.int32) self.assertFalse( compiled_computation_transformations.RaiseOnDisallowedOp( frozenset() @@ -353,11 +363,11 @@ def test_should_not_transform_non_compiled_computations(self): ) def test_transform_no_disallowed_ops(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) disallowed_op_names = frozenset(['ShardedFilename']) @@ -367,11 +377,11 @@ def test_transform_no_disallowed_ops(self): self.assertFalse(mutated) def test_transform_disallowed_ops(self): - tuple_type = computation_types.TensorType(np.int32) + tuple_type = federated_language.TensorType(np.int32) compiled_proto, compiled_type = ( tensorflow_computation_factory.create_identity(tuple_type) ) - compiled_computation = building_blocks.CompiledComputation( + compiled_computation = federated_language.framework.CompiledComputation( compiled_proto, name='a', type_signature=compiled_type ) disallowed_op_names = frozenset(['Identity']) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory.py index af756c061e..f3ced338a5 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory.py @@ -17,25 +17,20 @@ import functools from typing import Optional, Union +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis @functools.lru_cache() def _create_tensorflow_constant( - type_spec: computation_types.Type, + type_spec: federated_language.Type, scalar_value: Union[int, float, str], name=None, -) -> building_blocks.Call: +) -> federated_language.framework.Call: """Creates called graph returning constant `scalar_value` of type `type_spec`. `scalar_value` must be a scalar, and cannot be a float if any of the tensor @@ -43,15 +38,15 @@ def _create_tensorflow_constant( only named tuples and tensor types, but these can be arbitrarily nested. Args: - type_spec: A `computation_types.Type` whose resulting type tree can only + type_spec: A `federated_language.Type` whose resulting type tree can only contain named tuples and tensors. scalar_value: Scalar value to place in all the tensor leaves of `type_spec`. name: An optional string name to use as the name of the computation. Returns: - An instance of `building_blocks.Call`, whose argument is `None` + An instance of `federated_language.framework.Call`, whose argument is `None` and whose function is a noarg - `building_blocks.CompiledComputation` which returns the + `federated_language.framework.CompiledComputation` which returns the specified `scalar_value` packed into a TFF structure of type `type_spec. Raises: @@ -60,18 +55,18 @@ def _create_tensorflow_constant( proto, function_type = tensorflow_computation_factory.create_constant( scalar_value, type_spec ) - compiled = building_blocks.CompiledComputation( + compiled = federated_language.framework.CompiledComputation( proto, name, type_signature=function_type ) - return building_blocks.Call(compiled, None) + return federated_language.framework.Call(compiled, None) -def create_null_federated_aggregate() -> building_blocks.Call: +def create_null_federated_aggregate() -> federated_language.framework.Call: """Creates an aggregate over an empty struct and returns an empty struct.""" - unit = building_blocks.Struct([]) + unit = federated_language.framework.Struct([]) unit_type = unit.type_signature - value = building_block_factory.create_federated_value( - unit, placements.CLIENTS + value = federated_language.framework.create_federated_value( + unit, federated_language.CLIENTS ) zero = unit accumulate_proto, accumulate_type = ( @@ -79,73 +74,76 @@ def create_null_federated_aggregate() -> building_blocks.Call: lambda a, b: a, unit_type ) ) - accumulate = building_blocks.CompiledComputation( + accumulate = federated_language.framework.CompiledComputation( accumulate_proto, type_signature=accumulate_type ) merge = accumulate report_proto, report_type = tensorflow_computation_factory.create_identity( - computation_types.StructType([]) + federated_language.StructType([]) ) - report = building_blocks.CompiledComputation( + report = federated_language.framework.CompiledComputation( report_proto, type_signature=report_type ) - return building_block_factory.create_federated_aggregate( + return federated_language.framework.create_federated_aggregate( value, zero, accumulate, merge, report ) def create_null_federated_broadcast(): - return building_block_factory.create_federated_broadcast( - building_block_factory.create_federated_value( - building_blocks.Struct([]), placements.SERVER + return federated_language.framework.create_federated_broadcast( + federated_language.framework.create_federated_value( + federated_language.framework.Struct([]), federated_language.SERVER ) ) -def create_null_federated_map() -> building_blocks.Call: +def create_null_federated_map() -> federated_language.framework.Call: fn_proto, fn_type = tensorflow_computation_factory.create_identity( - computation_types.StructType([]) + federated_language.StructType([]) + ) + fn = federated_language.framework.CompiledComputation( + fn_proto, type_signature=fn_type ) - fn = building_blocks.CompiledComputation(fn_proto, type_signature=fn_type) - return building_block_factory.create_federated_map( + return federated_language.framework.create_federated_map( fn, - building_block_factory.create_federated_value( - building_blocks.Struct([]), placements.CLIENTS + federated_language.framework.create_federated_value( + federated_language.framework.Struct([]), federated_language.CLIENTS ), ) def create_null_federated_secure_sum(): - return building_block_factory.create_federated_secure_sum( - building_block_factory.create_federated_value( - building_blocks.Struct([]), placements.CLIENTS + return federated_language.framework.create_federated_secure_sum( + federated_language.framework.create_federated_value( + federated_language.framework.Struct([]), federated_language.CLIENTS ), - building_blocks.Struct([]), + federated_language.framework.Struct([]), ) def create_null_federated_secure_sum_bitwidth(): - return building_block_factory.create_federated_secure_sum_bitwidth( - building_block_factory.create_federated_value( - building_blocks.Struct([]), placements.CLIENTS + return federated_language.framework.create_federated_secure_sum_bitwidth( + federated_language.framework.create_federated_value( + federated_language.framework.Struct([]), federated_language.CLIENTS ), - building_blocks.Struct([]), + federated_language.framework.Struct([]), ) @functools.lru_cache() def create_generic_constant( - type_spec: Optional[computation_types.Type], scalar_value: Union[int, float] -) -> building_blocks.ComputationBuildingBlock: + type_spec: Optional[federated_language.Type], + scalar_value: Union[int, float], +) -> federated_language.framework.ComputationBuildingBlock: """Creates constant for a combination of federated, tuple and tensor types. Args: - type_spec: A `computation_types.Type` containing only federated, tuple or + type_spec: A `federated_language.Type` containing only federated, tuple or tensor types, or `None` to use to construct a generic constant. scalar_value: The scalar value we wish this constant to have. Returns: - Instance of `building_blocks.ComputationBuildingBlock` + Instance of `federated_language.framework.ComputationBuildingBlock` representing `scalar_value` packed into `type_spec`. Raises: @@ -155,76 +153,82 @@ def create_generic_constant( """ if type_spec is None: return _create_tensorflow_constant(type_spec, scalar_value) - py_typecheck.check_type(type_spec, computation_types.Type) + py_typecheck.check_type(type_spec, federated_language.Type) inferred_scalar_value_type = type_conversions.tensorflow_infer_type( scalar_value ) if not isinstance( - inferred_scalar_value_type, computation_types.TensorType - ) or not array_shape.is_shape_scalar(inferred_scalar_value_type.shape): + inferred_scalar_value_type, federated_language.TensorType + ) or not federated_language.array_shape_is_scalar( + inferred_scalar_value_type.shape + ): raise TypeError( 'Must pass a scalar value to `create_generic_constant`; encountered a ' 'value {}'.format(scalar_value) ) - def _check_parameters(type_spec: computation_types.Type) -> bool: + def _check_parameters(type_spec: federated_language.Type) -> bool: return isinstance( type_spec, ( - computation_types.FederatedType, - computation_types.StructType, - computation_types.TensorType, + federated_language.FederatedType, + federated_language.StructType, + federated_language.TensorType, ), ) - if not type_analysis.contains_only(type_spec, _check_parameters): + if not federated_language.framework.type_contains_only( + type_spec, _check_parameters + ): raise TypeError - def _predicate(type_spec: computation_types.Type) -> bool: + def _predicate(type_spec: federated_language.Type) -> bool: return isinstance( type_spec, ( - computation_types.StructType, - computation_types.TensorType, + federated_language.StructType, + federated_language.TensorType, ), ) - if type_analysis.contains_only(type_spec, _predicate): + if federated_language.framework.type_contains_only(type_spec, _predicate): return _create_tensorflow_constant(type_spec, scalar_value) - elif isinstance(type_spec, computation_types.FederatedType): + elif isinstance(type_spec, federated_language.FederatedType): unplaced_zero = _create_tensorflow_constant(type_spec.member, scalar_value) - if type_spec.placement is placements.CLIENTS: - placement_federated_type = computation_types.FederatedType( + if type_spec.placement is federated_language.CLIENTS: + placement_federated_type = federated_language.FederatedType( type_spec.member, type_spec.placement, all_equal=True ) - placement_fn_type = computation_types.FunctionType( + placement_fn_type = federated_language.FunctionType( type_spec.member, placement_federated_type ) - placement_function = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri, placement_fn_type + placement_function = federated_language.framework.Intrinsic( + federated_language.framework.FEDERATED_VALUE_AT_CLIENTS.uri, + placement_fn_type, ) - elif type_spec.placement is placements.SERVER: - placement_federated_type = computation_types.FederatedType( + elif type_spec.placement is federated_language.SERVER: + placement_federated_type = federated_language.FederatedType( type_spec.member, type_spec.placement, all_equal=True ) - placement_fn_type = computation_types.FunctionType( + placement_fn_type = federated_language.FunctionType( type_spec.member, placement_federated_type ) - placement_function = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri, placement_fn_type + placement_function = federated_language.framework.Intrinsic( + federated_language.framework.FEDERATED_VALUE_AT_SERVER.uri, + placement_fn_type, ) else: raise NotImplementedError( f'Unexpected placement found: {type_spec.placement}.' ) - return building_blocks.Call(placement_function, unplaced_zero) - elif isinstance(type_spec, computation_types.StructType): + return federated_language.framework.Call(placement_function, unplaced_zero) + elif isinstance(type_spec, federated_language.StructType): elements = [] for i, _ in enumerate(type_spec): elements.append(create_generic_constant(type_spec[i], scalar_value)) names = [name for name, _ in structure.iter_elements(type_spec)] - packed_elements = building_blocks.Struct(elements) - named_tuple = building_block_factory.create_named_tuple( + packed_elements = federated_language.framework.Struct(elements) + named_tuple = federated_language.framework.create_named_tuple( packed_elements, names, type_spec.python_container, @@ -238,9 +242,9 @@ def _predicate(type_spec: computation_types.Type) -> bool: def apply_binary_operator_with_upcast( - arg: building_blocks.ComputationBuildingBlock, + arg: federated_language.framework.ComputationBuildingBlock, operator: Callable[[object, object], object], -) -> building_blocks.Call: +) -> federated_language.framework.Call: """Constructs result of applying `operator` to `arg` upcasting if appropriate. Notice `arg` here must be of federated type, with a named tuple member of @@ -257,25 +261,27 @@ def apply_binary_operator_with_upcast( pointwise. Args: - arg: `building_blocks.ComputationBuildingBlock` of federated type whose - `member` attribute is a named tuple type of length 2, or named tuple type - of length 2. + arg: `federated_language.framework.ComputationBuildingBlock` of federated + type whose `member` attribute is a named tuple type of length 2, or named + tuple type of length 2. operator: Callable representing binary operator to apply to the 2-tuple represented by the federated `arg`. Returns: - Instance of `building_blocks.Call` + Instance of `federated_language.framework.Call` encapsulating the result of formally applying `operator` to `arg[0], `arg[1]`, upcasting `arg[1]` in the condition described above. Raises: TypeError: If the types don't match. """ - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - if isinstance(arg.type_signature, computation_types.FederatedType): + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) + if isinstance(arg.type_signature, federated_language.FederatedType): tuple_type = arg.type_signature.member - assert isinstance(tuple_type, computation_types.StructType) - elif isinstance(arg.type_signature, computation_types.StructType): + assert isinstance(tuple_type, federated_language.StructType) + elif isinstance(arg.type_signature, federated_language.StructType): tuple_type = arg.type_signature else: raise TypeError( @@ -288,15 +294,15 @@ def apply_binary_operator_with_upcast( operator, tuple_type ) ) - tf_representing_op = building_blocks.CompiledComputation( + tf_representing_op = federated_language.framework.CompiledComputation( tf_representing_proto, type_signature=tf_representing_type ) - if isinstance(arg.type_signature, computation_types.FederatedType): - called = building_block_factory.create_federated_map_or_apply( + if isinstance(arg.type_signature, federated_language.FederatedType): + called = federated_language.framework.create_federated_map_or_apply( tf_representing_op, arg ) else: - called = building_blocks.Call(tf_representing_op, arg) + called = federated_language.framework.Call(tf_representing_op, arg) return called diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory_test.py index e5c69017ab..46b05a18b7 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_building_block_factory_test.py @@ -14,14 +14,11 @@ from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements class CreateGenericConstantTest(absltest.TestCase): @@ -35,12 +32,12 @@ def test_raises_non_scalar(self): tensorflow_building_block_factory.create_generic_constant([np.int32], [0]) def test_constructs_tensor_zero(self): - tensor_type = computation_types.TensorType(np.float32, [2, 2]) + tensor_type = federated_language.TensorType(np.float32, [2, 2]) tensor_zero = tensorflow_building_block_factory.create_generic_constant( tensor_type, 0 ) self.assertEqual(tensor_zero.type_signature, tensor_type) - self.assertIsInstance(tensor_zero, building_blocks.Call) + self.assertIsInstance(tensor_zero, federated_language.framework.Call) self.assertTrue( np.array_equal( tensorflow_computation_test_utils.run_tensorflow( @@ -51,13 +48,13 @@ def test_constructs_tensor_zero(self): ) def test_create_unnamed_tuple_zero(self): - tensor_type = computation_types.TensorType(np.float32, [2, 2]) - tuple_type = computation_types.StructType((tensor_type, tensor_type)) + tensor_type = federated_language.TensorType(np.float32, [2, 2]) + tuple_type = federated_language.StructType((tensor_type, tensor_type)) tuple_zero = tensorflow_building_block_factory.create_generic_constant( tuple_type, 0 ) self.assertEqual(tuple_zero.type_signature, tuple_type) - self.assertIsInstance(tuple_zero, building_blocks.Call) + self.assertIsInstance(tuple_zero, federated_language.framework.Call) result = tensorflow_computation_test_utils.run_tensorflow( tuple_zero.function.proto ) @@ -66,8 +63,8 @@ def test_create_unnamed_tuple_zero(self): self.assertTrue(np.array_equal(result[1], np.zeros([2, 2]))) def test_create_named_tuple_one(self): - tensor_type = computation_types.TensorType(np.float32, [2, 2]) - tuple_type = computation_types.StructType( + tensor_type = federated_language.TensorType(np.float32, [2, 2]) + tuple_type = federated_language.StructType( [('a', tensor_type), ('b', tensor_type)] ) @@ -76,7 +73,7 @@ def test_create_named_tuple_one(self): ) self.assertEqual(tuple_zero.type_signature, tuple_type) - self.assertIsInstance(tuple_zero, building_blocks.Call) + self.assertIsInstance(tuple_zero, federated_language.framework.Call) result = tensorflow_computation_test_utils.run_tensorflow( tuple_zero.function.proto ) @@ -85,8 +82,9 @@ def test_create_named_tuple_one(self): self.assertTrue(np.array_equal(result.b, np.ones([2, 2]))) def test_create_federated_tensor_one(self): - fed_type = computation_types.FederatedType( - computation_types.TensorType(np.float32, [2, 2]), placements.CLIENTS + fed_type = federated_language.FederatedType( + federated_language.TensorType(np.float32, [2, 2]), + federated_language.CLIENTS, ) fed_zero = tensorflow_building_block_factory.create_generic_constant( fed_type, 1 @@ -94,12 +92,15 @@ def test_create_federated_tensor_one(self): self.assertEqual(fed_zero.type_signature.member, fed_type.member) self.assertEqual(fed_zero.type_signature.placement, fed_type.placement) self.assertTrue(fed_zero.type_signature.all_equal) - self.assertIsInstance(fed_zero, building_blocks.Call) - self.assertIsInstance(fed_zero.function, building_blocks.Intrinsic) + self.assertIsInstance(fed_zero, federated_language.framework.Call) + self.assertIsInstance( + fed_zero.function, federated_language.framework.Intrinsic + ) self.assertEqual( - fed_zero.function.uri, intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri + fed_zero.function.uri, + federated_language.framework.FEDERATED_VALUE_AT_CLIENTS.uri, ) - self.assertIsInstance(fed_zero.argument, building_blocks.Call) + self.assertIsInstance(fed_zero.argument, federated_language.framework.Call) self.assertTrue( np.array_equal( tensorflow_computation_test_utils.run_tensorflow( @@ -111,22 +112,27 @@ def test_create_federated_tensor_one(self): def test_create_federated_named_tuple_one(self): tuple_type = [ - ('a', computation_types.TensorType(np.float32, [2, 2])), - ('b', computation_types.TensorType(np.float32, [2, 2])), + ('a', federated_language.TensorType(np.float32, [2, 2])), + ('b', federated_language.TensorType(np.float32, [2, 2])), ] - fed_type = computation_types.FederatedType(tuple_type, placements.SERVER) + fed_type = federated_language.FederatedType( + tuple_type, federated_language.SERVER + ) fed_zero = tensorflow_building_block_factory.create_generic_constant( fed_type, 1 ) self.assertEqual(fed_zero.type_signature.member, fed_type.member) self.assertEqual(fed_zero.type_signature.placement, fed_type.placement) self.assertTrue(fed_zero.type_signature.all_equal) - self.assertIsInstance(fed_zero, building_blocks.Call) - self.assertIsInstance(fed_zero.function, building_blocks.Intrinsic) + self.assertIsInstance(fed_zero, federated_language.framework.Call) + self.assertIsInstance( + fed_zero.function, federated_language.framework.Intrinsic + ) self.assertEqual( - fed_zero.function.uri, intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri + fed_zero.function.uri, + federated_language.framework.FEDERATED_VALUE_AT_SERVER.uri, ) - self.assertIsInstance(fed_zero.argument, building_blocks.Call) + self.assertIsInstance(fed_zero.argument, federated_language.framework.Call) result = tensorflow_computation_test_utils.run_tensorflow( fed_zero.argument.function.proto ) @@ -135,12 +141,12 @@ def test_create_federated_named_tuple_one(self): self.assertTrue(np.array_equal(result.b, np.ones([2, 2]))) def test_create_named_tuple_of_federated_tensors_zero(self): - fed_type = computation_types.FederatedType( - computation_types.TensorType(np.float32, [2, 2]), - placements.CLIENTS, + fed_type = federated_language.FederatedType( + federated_language.TensorType(np.float32, [2, 2]), + federated_language.CLIENTS, all_equal=True, ) - tuple_type = computation_types.StructType( + tuple_type = federated_language.StructType( [('a', fed_type), ('b', fed_type)] ) @@ -150,11 +156,14 @@ def test_create_named_tuple_of_federated_tensors_zero(self): fed_zero = zero.argument[0] self.assertEqual(zero.type_signature, tuple_type) - self.assertIsInstance(fed_zero.function, building_blocks.Intrinsic) + self.assertIsInstance( + fed_zero.function, federated_language.framework.Intrinsic + ) self.assertEqual( - fed_zero.function.uri, intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri + fed_zero.function.uri, + federated_language.framework.FEDERATED_VALUE_AT_CLIENTS.uri, ) - self.assertIsInstance(fed_zero.argument, building_blocks.Call) + self.assertIsInstance(fed_zero.argument, federated_language.framework.Call) actual_result = tensorflow_computation_test_utils.run_tensorflow( fed_zero.argument.function.proto ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py index 603aa7d71b..b7869f74e5 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py @@ -17,26 +17,22 @@ import functools from typing import Optional, TypeVar +import federated_language +from federated_language.proto import computation_pb2 import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import type_transformations ComputationProtoAndType = tuple[ - computation_pb2.Computation, computation_types.Type + computation_pb2.Computation, federated_language.Type ] -T = TypeVar('T', bound=computation_types.Type) +T = TypeVar('T', bound=federated_language.Type) class TensorFlowComputationFactory: @@ -46,7 +42,7 @@ def __init__(self): pass def create_constant_from_scalar( - self, value, type_spec: computation_types.Type + self, value, type_spec: federated_language.Type ) -> ComputationProtoAndType: """Creates a TFF computation returning a constant based on a scalar value. @@ -57,11 +53,11 @@ def create_constant_from_scalar( value: A numpy scalar representing the value to return from the constructed computation (or to broadcast to all parts of a nested structure if `type_spec` is a structured type). - type_spec: A `computation_types.Type` of the constructed constant. Must be - either a tensor, or a nested structure of tensors. + type_spec: A `federated_language.Type` of the constructed constant. Must + be either a tensor, or a nested structure of tensors. Returns: - A tuple `(pb.Computation, computation_types.Type)` with the first element + A tuple `(pb.Computation, federated_language.Type)` with the first element being a TFF computation with semantics as described above, and the second element representing the formal type of that computation. """ @@ -72,7 +68,7 @@ def _tensorflow_comp( tensorflow_proto: computation_pb2.TensorFlow, type_signature: T, ) -> tuple[computation_pb2.Computation, T]: - serialized_type = type_serialization.serialize_type(type_signature) + serialized_type = federated_language.framework.serialize_type(type_signature) comp = computation_pb2.Computation( type=serialized_type, tensorflow=tensorflow_proto ) @@ -80,8 +76,8 @@ def _tensorflow_comp( def create_constant( - value, type_spec: computation_types.Type -) -> tuple[computation_pb2.Computation, computation_types.FunctionType]: + value, type_spec: federated_language.Type +) -> tuple[computation_pb2.Computation, federated_language.FunctionType]: """Returns a tensorflow computation returning a constant `value`. The returned computation has the type signature `( -> T)`, where `T` is @@ -93,21 +89,21 @@ def create_constant( Args: value: A value to embed as a constant in the tensorflow graph. - type_spec: A `computation_types.Type` to use as the argument to the + type_spec: A `federated_language.Type` to use as the argument to the constructed binary operator; must contain only named tuples and tensor types. Raises: TypeError: If the constraints of `type_spec` are violated. """ - if not type_analysis.is_generic_op_compatible_type(type_spec): + if not federated_language.framework.is_generic_op_compatible_type(type_spec): raise TypeError( 'Type spec {} cannot be constructed as a TensorFlow constant in TFF; ' ' only nested tuples and tensors are permitted.'.format(type_spec) ) inferred_value_type = type_conversions.tensorflow_infer_type(value) if isinstance( - inferred_value_type, computation_types.StructType + inferred_value_type, federated_language.StructType ) and not type_spec.is_assignable_from(inferred_value_type): raise TypeError( 'Must pass a only tensor or structure of tensor values to ' @@ -116,21 +112,21 @@ def create_constant( v=value, t=inferred_value_type, s=type_spec ) ) - if isinstance(inferred_value_type, computation_types.StructType): + if isinstance(inferred_value_type, federated_language.StructType): value = structure.from_container(value, recursive=True) tensor_dtypes_in_type_spec = [] def _pack_dtypes(type_signature): """Appends dtype of `type_signature` to nonlocal variable.""" - if isinstance(type_signature, computation_types.TensorType): + if isinstance(type_signature, federated_language.TensorType): tensor_dtypes_in_type_spec.append(type_signature.dtype) return type_signature, False - type_transformations.transform_type_postorder(type_spec, _pack_dtypes) + federated_language.framework.transform_type_postorder(type_spec, _pack_dtypes) if ( any(np.issubdtype(x, np.integer) for x in tensor_dtypes_in_type_spec) - and isinstance(inferred_value_type, computation_types.TensorType) + and isinstance(inferred_value_type, federated_language.TensorType) and not np.issubdtype(inferred_value_type.dtype, np.integer) ): raise TypeError( @@ -143,15 +139,15 @@ def _pack_dtypes(type_signature): def _create_result_tensor(type_spec, value): """Packs `value` into `type_spec` recursively.""" - if isinstance(type_spec, computation_types.TensorType): - if not array_shape.is_shape_fully_defined(type_spec.shape): + if isinstance(type_spec, federated_language.TensorType): + if not federated_language.array_shape_is_fully_defined(type_spec.shape): raise ValueError( f'Expected the shape to be fully defined, found {type_spec.shape}.' ) result = tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape) else: elements = [] - if isinstance(inferred_value_type, computation_types.StructType): + if isinstance(inferred_value_type, federated_language.StructType): # Copy the leaf values according to the type_spec structure. for (name, elem_type), value in zip( structure.iter_elements(type_spec), @@ -171,7 +167,7 @@ def _create_result_tensor(type_spec, value): result, graph ) - type_signature = computation_types.FunctionType(None, result_type) + type_signature = federated_language.FunctionType(None, result_type) tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=None, @@ -181,7 +177,7 @@ def _create_result_tensor(type_spec, value): def create_unary_operator( - operator: Callable[..., object], operand_type: computation_types.Type + operator: Callable[..., object], operand_type: federated_language.Type ) -> ComputationProtoAndType: """Returns a tensorflow computation computing a unary operation. @@ -192,7 +188,7 @@ def create_unary_operator( Args: operator: A callable taking one argument representing the operation to encode For example: `tf.math.abs`. - operand_type: A `computation_types.Type` to use as the argument to the + operand_type: A `federated_language.Type` to use as the argument to the constructed unary operator; must contain only named tuples and tensor types. @@ -200,12 +196,15 @@ def create_unary_operator( TypeError: If the constraints of `operand_type` are violated or `operator` is not callable. """ - if operand_type is None or not type_analysis.is_generic_op_compatible_type( - operand_type + if ( + operand_type is None + or not federated_language.framework.is_generic_op_compatible_type( + operand_type + ) ): raise TypeError( '`operand_type` contains a type other than ' - '`computation_types.TensorType` and `computation_types.StructType`; ' + '`federated_language.TensorType` and `federated_language.StructType`; ' f'this is disallowed in the generic operators. Got: {operand_type} ' ) @@ -218,7 +217,7 @@ def create_unary_operator( result_value, graph ) - type_signature = computation_types.FunctionType(operand_type, result_type) + type_signature = federated_language.FunctionType(operand_type, result_type) parameter_binding = operand_binding tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), @@ -230,8 +229,8 @@ def create_unary_operator( def create_binary_operator( operator: Callable[..., object], - operand_type: computation_types.Type, - second_operand_type: Optional[computation_types.Type] = None, + operand_type: federated_language.Type, + second_operand_type: Optional[federated_language.Type] = None, ) -> ComputationProtoAndType: """Returns a tensorflow computation computing a binary operation. @@ -239,7 +238,7 @@ def create_binary_operator( `operand_type` and `U` is the result of applying the `operator` to a tuple of type `` - Note: If `operand_type` is a `computation_types.StructType`, then + Note: If `operand_type` is a `federated_language.StructType`, then `operator` will be applied pointwise. This places the burden on callers of this function to construct the correct values to pass into the returned function. For example, to divide `[2, 2]` by `2`, first `2` must be packed @@ -250,10 +249,10 @@ def create_binary_operator( operator: A callable taking two arguments representing the operation to encode For example: `tf.math.add`, `tf.math.multiply`, and `tf.math.divide`. - operand_type: A `computation_types.Type` to use as the argument to the + operand_type: A `federated_language.Type` to use as the argument to the constructed binary operator; must contain only named tuples and tensor types. - second_operand_type: An optional `computation_types.Type` to use as the + second_operand_type: An optional `federated_language.Type` to use as the seocnd argument to the constructed binary operator. If `None`, operator uses `operand_type` for both arguments. Must contain only named tuples and tensor types. @@ -262,19 +261,23 @@ def create_binary_operator( TypeError: If the constraints of `operand_type` are violated or `operator` is not callable. """ - if not type_analysis.is_generic_op_compatible_type(operand_type): + if not federated_language.framework.is_generic_op_compatible_type( + operand_type + ): raise TypeError( '`operand_type` contains a type other than ' - '`computation_types.TensorType` and `computation_types.StructType`; ' + '`federated_language.TensorType` and `federated_language.StructType`; ' f'this is disallowed in the generic operators. Got: {operand_type} ' ) if second_operand_type is not None: - if not type_analysis.is_generic_op_compatible_type(second_operand_type): + if not federated_language.framework.is_generic_op_compatible_type( + second_operand_type + ): raise TypeError( - '`second_operand_type` contains a type other than ' - '`computation_types.TensorType` and `computation_types.StructType`; ' - 'this is disallowed in the generic operators. ' - f'Got: {second_operand_type} ' + '`second_operand_type` contains a type other than' + ' `federated_language.TensorType` and' + ' `federated_language.StructType`; this is disallowed in the generic' + f' operators. Got: {second_operand_type} ' ) elif second_operand_type is None: second_operand_type = operand_type @@ -293,8 +296,8 @@ def create_binary_operator( result_value, graph ) - type_signature = computation_types.FunctionType( - computation_types.StructType((operand_type, second_operand_type)), + type_signature = federated_language.FunctionType( + federated_language.StructType((operand_type, second_operand_type)), result_type, ) parameter_binding = computation_pb2.TensorFlow.Binding( @@ -312,13 +315,13 @@ def create_binary_operator( def create_binary_operator_with_upcast( operator: Callable[[object, object], object], - type_signature: computation_types.StructType, + type_signature: federated_language.StructType, ) -> ComputationProtoAndType: """Creates TF computation upcasting its argument and applying `operator`. Args: operator: Callable defining the operator. - type_signature: A `computation_types.StructType` with two elements, both + type_signature: A `federated_language.StructType` with two elements, both only containing structs or tensors in their type tree. The first and second element must match in structure, or the second element may be a single tensor type that is broadcasted (upcast) to the leaves of the @@ -331,10 +334,10 @@ def create_binary_operator_with_upcast( Returns: Same as `create_binary_operator()`. """ - py_typecheck.check_type(type_signature, computation_types.StructType) - type_analysis.check_tensorflow_compatible_type(type_signature) + py_typecheck.check_type(type_signature, federated_language.StructType) + federated_language.framework.check_tensorflow_compatible_type(type_signature) if ( - not isinstance(type_signature, computation_types.StructType) + not isinstance(type_signature, federated_language.StructType) or len(type_signature) != 2 ): raise TypeError( @@ -344,8 +347,8 @@ def create_binary_operator_with_upcast( t=type_signature ) ) - if type_analysis.contains( - type_signature, lambda t: isinstance(t, computation_types.SequenceType) + if federated_language.framework.type_contains( + type_signature, lambda t: isinstance(t, federated_language.SequenceType) ): raise TypeError( 'Applying binary operators in TensorFlow is only ' @@ -353,19 +356,19 @@ def create_binary_operator_with_upcast( 'passed {t} which contains a SequenceType.'.format(t=type_signature) ) - def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type): + def _pack_into_type(to_pack: tf.Tensor, type_spec: federated_language.Type): """Pack Tensor value `to_pack` into the nested structure `type_spec`.""" - if isinstance(type_spec, computation_types.StructType): + if isinstance(type_spec, federated_language.StructType): elem_iter = structure.iter_elements(type_spec) return structure.Struct([ (elem_name, _pack_into_type(to_pack, elem_type)) for elem_name, elem_type in elem_iter ]) - elif isinstance(type_spec, computation_types.TensorType): + elif isinstance(type_spec, federated_language.TensorType): value_tensor_type = type_conversions.tensorflow_infer_type(to_pack) if type_spec.is_assignable_from(value_tensor_type): return to_pack - elif not array_shape.is_shape_fully_defined(type_spec.shape): + elif not federated_language.array_shape_is_fully_defined(type_spec.shape): raise TypeError( 'Cannot generate TensorFlow creating binary operator ' 'with first type not assignable from second, and ' @@ -385,8 +388,8 @@ def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type): ) if isinstance( - type_signature[0], computation_types.StructType - ) and isinstance(type_signature[1], computation_types.StructType): + type_signature[0], federated_language.StructType + ) and isinstance(type_signature[1], federated_language.StructType): # If both the first and second arguments are structs with the same # structure, simply re-use operand_2_value as. `tf.nest.map_structure` # below will map the binary operator pointwise to the leaves of the @@ -406,9 +409,9 @@ def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type): type_signature[0], # pytype: disable=wrong-arg-types ) - if isinstance(type_signature[0], computation_types.TensorType): + if isinstance(type_signature[0], federated_language.TensorType): result_value = operator(first_arg, second_arg) - elif isinstance(type_signature[0], computation_types.StructType): + elif isinstance(type_signature[0], federated_language.StructType): result_value = structure.map_structure( operator, first_arg, # pytype: disable=wrong-arg-types @@ -424,7 +427,7 @@ def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type): result_value, graph ) - type_signature = computation_types.FunctionType(type_signature, result_type) + type_signature = federated_language.FunctionType(type_signature, result_type) parameter_binding = computation_pb2.TensorFlow.Binding( struct=computation_pb2.TensorFlow.StructBinding( element=[operand_1_binding, operand_2_binding] @@ -447,24 +450,24 @@ def create_empty_tuple() -> ComputationProtoAndType: def create_identity( - type_signature: computation_types.Type, + type_signature: federated_language.Type, ) -> ComputationProtoAndType: """Returns a tensorflow computation representing an identity function. The returned computation has the type signature `(T -> T)`, where `T` is - `type_signature`. NOTE: if `T` contains `computation_types.StructType`s + `type_signature`. NOTE: if `T` contains `federated_language.StructType`s without an associated container type, they will be given the container type `tuple` by this function. Args: - type_signature: A `computation_types.Type` to use as the parameter type and + type_signature: A `federated_language.Type` to use as the parameter type and result type of the identity function. Raises: TypeError: If `type_signature` contains any types which cannot appear in TensorFlow bindings. """ - type_analysis.check_tensorflow_compatible_type(type_signature) + federated_language.framework.check_tensorflow_compatible_type(type_signature) parameter_type = type_signature if parameter_type is None: raise TypeError('TensorFlow identity cannot be created for NoneType.') @@ -473,12 +476,12 @@ def create_identity( if isinstance( type_signature, ( - computation_types.SequenceType, - computation_types.TensorType, + federated_language.SequenceType, + federated_language.TensorType, ), ): identity_fn = tf.identity - elif isinstance(type_signature, computation_types.StructType): + elif isinstance(type_signature, federated_language.StructType): identity_fn = functools.partial(structure.map_structure, tf.identity) else: raise NotImplementedError( @@ -490,7 +493,7 @@ def create_identity( def create_computation_for_py_fn( fn: Callable[..., object], - parameter_type: Optional[computation_types.Type], + parameter_type: Optional[federated_language.Type], ) -> ComputationProtoAndType: """Returns a tensorflow computation returning the result of `fn`. @@ -499,10 +502,10 @@ def create_computation_for_py_fn( Args: fn: A Python function. - parameter_type: A `computation_types.Type` or `None`. + parameter_type: A `federated_language.Type` or `None`. """ if parameter_type is not None: - py_typecheck.check_type(parameter_type, computation_types.Type) + py_typecheck.check_type(parameter_type, federated_language.Type) with tf.Graph().as_default() as graph: if parameter_type is not None: @@ -517,7 +520,7 @@ def create_computation_for_py_fn( result, graph ) - type_signature = computation_types.FunctionType(parameter_type, result_type) + type_signature = federated_language.FunctionType(parameter_type, result_type) tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory_test.py index 45cfb365b8..551c02133b 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory_test.py @@ -16,38 +16,36 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language +from federated_language.proto import computation_pb2 as pb import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_test_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_serialization class CreateConstantTest(parameterized.TestCase): # pyformat: disable @parameterized.named_parameters( - ('scalar_int', 10, computation_types.TensorType(np.int32, [3]), [10] * 3), - ('scalar_float', 10.0, computation_types.TensorType(np.float32, [3]), [10.0] * 3), + ('scalar_int', 10, federated_language.TensorType(np.int32, [3]), [10] * 3), + ('scalar_float', 10.0, federated_language.TensorType(np.float32, [3]), [10.0] * 3), ('scalar_with_unnamed_struct_type', 10, - computation_types.StructType([np.int32] * 3), + federated_language.StructType([np.int32] * 3), structure.Struct([(None, 10)] * 3)), ('scalar_with_named_struct_type', 10, - computation_types.StructType([('a', np.int32), ('b', np.int32), ('c', np.int32)]), + federated_language.StructType([('a', np.int32), ('b', np.int32), ('c', np.int32)]), structure.Struct([('a', 10), ('b', 10), ('c', 10)])), ('scalar_with_nested_struct_type', 10, - computation_types.StructType([[np.int32] * 3] * 3), + federated_language.StructType([[np.int32] * 3] * 3), structure.Struct([(None, structure.Struct([(None, 10)] * 3))] * 3)), ('tuple_with_struct_type', (10, 11, 12), - computation_types.StructType([np.int32, np.int32, np.int32]), + federated_language.StructType([np.int32, np.int32, np.int32]), structure.Struct([(None, 10), (None, 11), (None, 12)])), ('nested_struct_with_nested_struct_type', (10, (11, 12)), - computation_types.StructType([np.int32, [np.int32, np.int32]]), + federated_language.StructType([np.int32, [np.int32, np.int32]]), structure.Struct([ (None, 10), (None, structure.Struct([ @@ -55,7 +53,7 @@ class CreateConstantTest(parameterized.TestCase): ])), ('nested_named_struct_with_nested_struct_type', collections.OrderedDict(a=10, b=collections.OrderedDict(c=11, d=12)), - computation_types.StructType( + federated_language.StructType( collections.OrderedDict(a=np.int32, b=collections.OrderedDict( c=np.int32, d=np.int32))), @@ -65,7 +63,7 @@ class CreateConstantTest(parameterized.TestCase): ('c', 11), ('d', 12)])) ])), ('unnamed_value_named_type', (10.0,), - computation_types.StructType([('a', np.float32)]), + federated_language.StructType([('a', np.float32)]), structure.Struct([('a', 10.0)])), ) # pyformat: enable @@ -75,8 +73,8 @@ def test_returns_computation(self, value, type_signature, expected_result): ) self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = computation_types.FunctionType(None, type_signature) + actual_type = federated_language.framework.deserialize_type(proto.type) + expected_type = federated_language.FunctionType(None, type_signature) expected_type.check_assignable_from(actual_type) actual_result = tensorflow_computation_test_utils.run_tensorflow(proto) if isinstance(expected_result, list): @@ -88,31 +86,31 @@ def test_returns_computation(self, value, type_signature, expected_result): ( 'non_scalar_value', np.zeros([1]), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ('none_type', 10, None), ( 'federated_type', 10, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), ), - ('bad_type', 10.0, computation_types.TensorType(np.int32)), + ('bad_type', 10.0, federated_language.TensorType(np.int32)), ( 'value_structure_larger_than_type_structure', (10.0, 11.0), - computation_types.StructType([np.float32]), + federated_language.StructType([np.float32]), ), ( 'value_structure_smaller_than_type_structure', (10.0,), - computation_types.StructType( + federated_language.StructType( [(None, np.float32), (None, np.float32)] ), ), ( 'named_value_unnamed_type', collections.OrderedDict(a=10.0), - computation_types.StructType([(None, np.float32)]), + federated_language.StructType([(None, np.float32)]), ), ) def test_raises_type_error(self, value, type_signature): @@ -123,20 +121,26 @@ def test_raises_type_error(self, value, type_signature): class CreateUnaryOperatorTest(parameterized.TestCase, tf.test.TestCase): @parameterized.named_parameters( - ('abs_int', tf.math.abs, computation_types.TensorType(np.int32), [-1], 1), + ( + 'abs_int', + tf.math.abs, + federated_language.TensorType(np.int32), + [-1], + 1, + ), ( 'abs_float', tf.math.abs, - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), [-1.0], 1.0, ), ( 'abs_unnamed_tuple', lambda x: structure.map_structure(tf.math.abs, x), - computation_types.StructType([ - computation_types.TensorType(np.int32, [2]), - computation_types.TensorType(np.float32, [2]), + federated_language.StructType([ + federated_language.TensorType(np.int32, [2]), + federated_language.TensorType(np.float32, [2]), ]), [[-1, -2], [-3.0, -4.0]], structure.Struct([(None, [1, 2]), (None, [3.0, 4.0])]), @@ -144,9 +148,9 @@ class CreateUnaryOperatorTest(parameterized.TestCase, tf.test.TestCase): ( 'abs_named_tuple', lambda x: structure.map_structure(tf.math.abs, x), - computation_types.StructType([ - ('a', computation_types.TensorType(np.int32, [2])), - ('b', computation_types.TensorType(np.float32, [2])), + federated_language.StructType([ + ('a', federated_language.TensorType(np.int32, [2])), + ('b', federated_language.TensorType(np.float32, [2])), ]), [[-1, -2], [-3.0, -4.0]], structure.Struct([('a', [1, 2]), ('b', [3.0, 4.0])]), @@ -154,21 +158,21 @@ class CreateUnaryOperatorTest(parameterized.TestCase, tf.test.TestCase): ( 'reduce_sum_int', tf.math.reduce_sum, - computation_types.TensorType(np.int32, [2]), + federated_language.TensorType(np.int32, [2]), [2, 2], 4, ), ( 'reduce_sum_float', tf.math.reduce_sum, - computation_types.TensorType(np.float32, [2]), + federated_language.TensorType(np.float32, [2]), [2.0, 2.5], 4.5, ), ( 'log_inf', tf.math.log, - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), [0.0], -np.inf, ), @@ -182,8 +186,8 @@ def test_returns_computation( ) self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - self.assertIsInstance(actual_type, computation_types.FunctionType) + actual_type = federated_language.framework.deserialize_type(proto.type) + self.assertIsInstance(actual_type, federated_language.FunctionType) # Note: It is only useful to test the parameter type; the result type # depends on the `operator` used, not the implemenation # `create_unary_operator`. @@ -195,14 +199,14 @@ def test_returns_computation( self.assertAllEqual(actual_result, expected_result) @parameterized.named_parameters( - ('non_callable_operator', 1, computation_types.TensorType(np.int32)), + ('non_callable_operator', 1, federated_language.TensorType(np.int32)), ('none_type', tf.math.add, None), ( 'federated_type', tf.math.add, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), ), - ('sequence_type', tf.math.add, computation_types.SequenceType(np.int32)), + ('sequence_type', tf.math.add, federated_language.SequenceType(np.int32)), ) def test_raises_type_error(self, operator, type_signature): with self.assertRaises(TypeError): @@ -216,40 +220,40 @@ class CreateBinaryOperatorTest(parameterized.TestCase): # pyformat: disable @parameterized.named_parameters( ('add_int', tf.math.add, - computation_types.TensorType(np.int32), None, + federated_language.TensorType(np.int32), None, [1, 2], 3), ('add_float', tf.math.add, - computation_types.TensorType(np.float32), None, + federated_language.TensorType(np.float32), None, [1.0, 2.25], 3.25), ('add_unnamed_tuple', lambda x, y: structure.map_structure(tf.math.add, x, y), - computation_types.StructType([np.int32, np.float32]), None, + federated_language.StructType([np.int32, np.float32]), None, [[1, 1.0], [2, 2.25]], structure.Struct([(None, 3), (None, 3.25)])), ('add_named_tuple', lambda x, y: structure.map_structure(tf.math.add, x, y), - computation_types.StructType([('a', np.int32), ('b', np.float32)]), None, + federated_language.StructType([('a', np.int32), ('b', np.float32)]), None, [[1, 1.0], [2, 2.25]], structure.Struct([('a', 3), ('b', 3.25)])), ('multiply_int', tf.math.multiply, - computation_types.TensorType(np.int32), None, + federated_language.TensorType(np.int32), None, [2, 2], 4), ('multiply_float', tf.math.multiply, - computation_types.TensorType(np.float32), None, + federated_language.TensorType(np.float32), None, [2.0, 2.25], 4.5), ('divide_int', tf.math.divide, - computation_types.TensorType(np.int32), None, + federated_language.TensorType(np.int32), None, [4, 2], 2.0), ('divide_float', tf.math.divide, - computation_types.TensorType(np.float32), None, + federated_language.TensorType(np.float32), None, [4.0, 2.0], 2.0), ('divide_inf', tf.math.divide, - computation_types.TensorType(np.int32), None, + federated_language.TensorType(np.int32), None, [1, 0], np.inf), ('different_structure', lambda x, y: structure.map_structure(lambda v: tf.math.divide(v, y), x), - computation_types.StructType([np.float32, np.float32]), - computation_types.TensorType(np.float32), + federated_language.StructType([np.float32, np.float32]), + federated_language.TensorType(np.float32), [[1, 2], 2], structure.Struct([(None, 0.5), (None, 1.0)])), ) @@ -267,17 +271,17 @@ def test_returns_computation( ) self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - self.assertIsInstance(actual_type, computation_types.FunctionType) + actual_type = federated_language.framework.deserialize_type(proto.type) + self.assertIsInstance(actual_type, federated_language.FunctionType) # Note: It is only useful to test the parameter type; the result type # depends on the `operator` used, not the implemenation # `create_binary_operator`. if second_operand_type is None: - expected_parameter_type = computation_types.StructType( + expected_parameter_type = federated_language.StructType( [operand_type, operand_type] ) else: - expected_parameter_type = computation_types.StructType( + expected_parameter_type = federated_language.StructType( [operand_type, second_operand_type] ) self.assertEqual(actual_type.parameter, expected_parameter_type) @@ -287,14 +291,14 @@ def test_returns_computation( self.assertEqual(actual_result, expected_result) @parameterized.named_parameters( - ('non_callable_operator', 1, computation_types.TensorType(np.int32)), + ('non_callable_operator', 1, federated_language.TensorType(np.int32)), ('none_type', tf.math.add, None), ( 'federated_type', tf.math.add, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), ), - ('sequence_type', tf.math.add, computation_types.SequenceType(np.int32)), + ('sequence_type', tf.math.add, federated_language.SequenceType(np.int32)), ) def test_raises_type_error(self, operator, type_signature): with self.assertRaises(TypeError): @@ -312,80 +316,80 @@ class CreateBinaryOperatorWithUpcastTest(parameterized.TestCase): # pyformat: disable @parameterized.named_parameters( ('add_int_same_shape', tf.math.add, - computation_types.StructType([computation_types.TensorType(np.int32), computation_types.TensorType(np.int32)]), + federated_language.StructType([federated_language.TensorType(np.int32), federated_language.TensorType(np.int32)]), [1, 2], 3), ('add_int_different_shape', tf.math.add, - computation_types.StructType([computation_types.TensorType(np.int64, [1]), computation_types.TensorType(np.int32)]), + federated_language.StructType([federated_language.TensorType(np.int64, [1]), federated_language.TensorType(np.int32)]), [np.array([1]), 2], 3), ('add_int_different_types', tf.math.add, - computation_types.StructType([ - computation_types.StructType([ - computation_types.TensorType(np.int64, [1])]), - computation_types.TensorType(np.int32), + federated_language.StructType([ + federated_language.StructType([ + federated_language.TensorType(np.int64, [1])]), + federated_language.TensorType(np.int32), ]), [[np.array([1])], 2], structure.Struct([(None, 3)])), ('multiply_int_same_shape', tf.math.multiply, - computation_types.StructType([computation_types.TensorType(np.int32), computation_types.TensorType(np.int32)]), + federated_language.StructType([federated_language.TensorType(np.int32), federated_language.TensorType(np.int32)]), [1, 2], 2), ('multiply_int_different_shape', tf.math.multiply, - computation_types.StructType([computation_types.TensorType(np.int64, [1]), computation_types.TensorType(np.int32)]), + federated_language.StructType([federated_language.TensorType(np.int64, [1]), federated_language.TensorType(np.int32)]), [np.array([1]), 2], 2), ('multiply_int_different_types', tf.math.multiply, - computation_types.StructType([ - computation_types.StructType([ - computation_types.TensorType(np.int64, [1])]), - computation_types.TensorType(np.int32) + federated_language.StructType([ + federated_language.StructType([ + federated_language.TensorType(np.int64, [1])]), + federated_language.TensorType(np.int32) ]), [[np.array([1])], 2], structure.Struct([(None, 2)])), ('divide_int_same_shape', tf.math.divide, - computation_types.StructType([computation_types.TensorType(np.int32), computation_types.TensorType(np.int32)]), + federated_language.StructType([federated_language.TensorType(np.int32), federated_language.TensorType(np.int32)]), [1, 2], 0.5), ('divide_int_different_shape', tf.math.divide, - computation_types.StructType([computation_types.TensorType(np.int64, [1]), computation_types.TensorType(np.int32)]), + federated_language.StructType([federated_language.TensorType(np.int64, [1]), federated_language.TensorType(np.int32)]), [np.array([1]), 2], 0.5), ('divide_int_different_types', tf.math.divide, - computation_types.StructType([ - computation_types.StructType([ - computation_types.TensorType(np.int64, [1])]), - computation_types.TensorType(np.int32), + federated_language.StructType([ + federated_language.StructType([ + federated_language.TensorType(np.int64, [1])]), + federated_language.TensorType(np.int32), ]), [[np.array([1])], 2], structure.Struct([(None, 0.5)])), ('divide_int_same_structure', tf.math.divide, - computation_types.StructType([ - computation_types.StructType([ - computation_types.TensorType(np.int64, [1]), - computation_types.TensorType(np.int64, [1]), + federated_language.StructType([ + federated_language.StructType([ + federated_language.TensorType(np.int64, [1]), + federated_language.TensorType(np.int64, [1]), ]), - computation_types.StructType([ - computation_types.TensorType(np.int64), - computation_types.TensorType(np.int64), + federated_language.StructType([ + federated_language.TensorType(np.int64), + federated_language.TensorType(np.int64), ]), ]), [[np.array([1]), np.array([2])], [2, 8]], structure.Struct([(None, 0.5), (None, 0.25)])), ('add_float_unknown_shape', tf.math.add, - computation_types.StructType([ - computation_types.TensorType(np.float64, [None]), - computation_types.TensorType(np.float64, [1]) + federated_language.StructType([ + federated_language.TensorType(np.float64, [None]), + federated_language.TensorType(np.float64, [1]) ]), [np.array([1.0]), np.array([2.25])], np.array([3.25])), ('add_float_unknown_rank', tf.math.add, - computation_types.StructType([ - computation_types.TensorType(np.float64, None), - computation_types.TensorType(np.float64, [1]) + federated_language.StructType([ + federated_language.TensorType(np.float64, None), + federated_language.TensorType(np.float64, [1]) ]), [np.array([1.0]), np.array([2.25])], np.array([3.25])), ('add_float_unknown_shape_inside_struct', tf.math.add, - computation_types.StructType([ - computation_types.StructType([ - computation_types.TensorType(np.float64, [None]) + federated_language.StructType([ + federated_language.StructType([ + federated_language.TensorType(np.float64, [None]) ]), - computation_types.TensorType(np.float64, [1]) + federated_language.TensorType(np.float64, [1]) ]), [[np.array([1.0])], np.array([2.25])], structure.Struct.unnamed([np.array([3.25])])), @@ -401,12 +405,12 @@ def test_returns_computation( ) self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - self.assertIsInstance(actual_type, computation_types.FunctionType) + actual_type = federated_language.framework.deserialize_type(proto.type) + self.assertIsInstance(actual_type, federated_language.FunctionType) # Note: It is only useful to test the parameter type; the result type # depends on the `operator` used, not the implemenation # `create_binary_operator_with_upcast`. - expected_parameter_type = computation_types.StructType(type_signature) + expected_parameter_type = federated_language.StructType(type_signature) self.assertEqual(actual_type.parameter, expected_parameter_type) actual_result = tensorflow_computation_test_utils.run_tensorflow( proto, operands @@ -417,22 +421,22 @@ def test_returns_computation( ( 'different_structures', tf.math.add, - computation_types.StructType([ - computation_types.StructType([ - computation_types.TensorType(np.int32), + federated_language.StructType([ + federated_language.StructType([ + federated_language.TensorType(np.int32), ]), - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), + federated_language.StructType([ + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), ]), ]), ), ( 'shape_incompatible', tf.math.add, - computation_types.StructType([ - computation_types.TensorType(np.float64, [None]), - computation_types.TensorType(np.float64, [1, 1]), + federated_language.StructType([ + federated_language.TensorType(np.float64, [None]), + federated_language.TensorType(np.float64, [1, 1]), ]), ), ) @@ -449,8 +453,8 @@ def test_returns_computation(self): proto, _ = tensorflow_computation_factory.create_empty_tuple() self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = computation_types.FunctionType(None, []) + actual_type = federated_language.framework.deserialize_type(proto.type) + expected_type = federated_language.FunctionType(None, []) expected_type.check_assignable_from(actual_type) actual_result = tensorflow_computation_test_utils.run_tensorflow(proto) expected_result = structure.Struct([]) @@ -461,22 +465,22 @@ class CreateIdentityTest(parameterized.TestCase): # pyformat: disable @parameterized.named_parameters( - ('int', computation_types.TensorType(np.int32), 10), + ('int', federated_language.TensorType(np.int32), 10), ('unnamed_tuple', - computation_types.StructType([np.int32, np.float32]), + federated_language.StructType([np.int32, np.float32]), structure.Struct([(None, 10), (None, 10.0)])), ('named_tuple', - computation_types.StructType([('a', np.int32), ('b', np.float32)]), + federated_language.StructType([('a', np.int32), ('b', np.float32)]), structure.Struct([('a', 10), ('b', 10.0)])), - ('sequence', computation_types.SequenceType(np.int32), [10] * 3), + ('sequence', federated_language.SequenceType(np.int32), [10] * 3), ) # pyformat: enable def test_returns_computation(self, type_signature, value): proto, _ = tensorflow_computation_factory.create_identity(type_signature) self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = computation_types.FunctionType( + actual_type = federated_language.framework.deserialize_type(proto.type) + expected_type = federated_language.FunctionType( type_signature, type_signature ) self.assertEqual(actual_type, expected_type) @@ -489,7 +493,7 @@ def test_returns_computation(self, type_signature, value): ('none', None), ( 'federated_type', - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), ), ) def test_raises_type_error(self, type_signature): @@ -498,7 +502,7 @@ def test_raises_type_error(self, type_signature): def test_feeds_and_fetches_different(self): proto, _ = tensorflow_computation_factory.create_identity( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ) self.assertNotEqual(proto.tensorflow.parameter, proto.tensorflow.result) @@ -508,15 +512,15 @@ class CreateComputationForPyFnTest(parameterized.TestCase): # pyformat: disable @parameterized.named_parameters( ('const', lambda: 10, None, None, 10), - ('identity', lambda x: x, computation_types.TensorType(np.int32), 10, 10), + ('identity', lambda x: x, federated_language.TensorType(np.int32), 10, 10), ('add_one', lambda x: x + 1, - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), 10, 11), ('dataset_reduce', lambda ds: ds.reduce(np.int32(0), lambda x, y: x + y), - computation_types.SequenceType(np.int32), + federated_language.SequenceType(np.int32), list(range(10)), 45), ) @@ -533,9 +537,9 @@ def test_returns_computation( self.assertEqual(actual_result, expected_result) @parameterized.named_parameters( - ('none_py_fn', None, computation_types.TensorType(np.int32)), + ('none_py_fn', None, federated_language.TensorType(np.int32)), ('none_type', lambda x: x, None), - ('unnecessary_type', lambda: 10, computation_types.TensorType(np.int32)), + ('unnecessary_type', lambda: 10, federated_language.TensorType(np.int32)), ) def test_raises_type_error_with_none(self, py_fn, type_signature): with self.assertRaises(TypeError): diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_test_utils.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_test_utils.py index b41886ffa0..813f775284 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_test_utils.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_test_utils.py @@ -15,19 +15,18 @@ from typing import Optional +import federated_language +from federated_language.proto import computation_pb2 as pb import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization def _stamp_value_into_graph( value: Optional[object], - type_signature: computation_types.Type, + type_signature: federated_language.Type, graph: tf.Graph, ) -> object: """Stamps `value` in `graph` as an object of type `type_signature`. @@ -44,9 +43,9 @@ def _stamp_value_into_graph( """ if value is None: return None - if isinstance(type_signature, computation_types.TensorType): + if isinstance(type_signature, federated_language.TensorType): if isinstance(value, np.ndarray) or tf.is_tensor(value): - value_type = computation_types.TensorType(value.dtype, value.shape) + value_type = federated_language.TensorType(value.dtype, value.shape) type_signature.check_assignable_from(value_type) with graph.as_default(): return tf.constant(value) @@ -57,7 +56,7 @@ def _stamp_value_into_graph( dtype=type_signature.dtype, # pytype: disable=attribute-error shape=type_signature.shape, # pytype: disable=attribute-error ) - elif isinstance(type_signature, computation_types.StructType): + elif isinstance(type_signature, federated_language.StructType): if isinstance(value, (list, dict)): value = structure.from_container(value) stamped_elements = [] @@ -66,7 +65,7 @@ def _stamp_value_into_graph( stamped_element = _stamp_value_into_graph(element, type_signature, graph) stamped_elements.append((name, stamped_element)) return structure.Struct(stamped_elements) - elif isinstance(type_signature, computation_types.SequenceType): + elif isinstance(type_signature, federated_language.SequenceType): return tensorflow_utils.make_data_set_from_elements( graph, value, type_signature.element ) @@ -90,7 +89,9 @@ def run_tensorflow( The result of the computation. """ with tf.Graph().as_default() as graph: - type_signature = type_serialization.deserialize_type(computation_proto.type) + type_signature = federated_language.framework.deserialize_type( + computation_proto.type + ) if type_signature.parameter is not None: # pytype: disable=attribute-error stamped_arg = _stamp_value_into_graph( arg, diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations.py index 785ebac0fc..8e998a8ca8 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations.py @@ -16,9 +16,9 @@ import itertools from typing import Optional +from federated_language.proto import computation_pb2 import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations_test.py index e55df8a5eb..0bed1ecad2 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_transformations_test.py @@ -16,17 +16,15 @@ import itertools from absl.testing import absltest +import federated_language +from federated_language.proto import computation_pb2 import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_transformations from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization def _extract_call_ops( @@ -65,11 +63,11 @@ def test_raises_on_none(self): ) def test_raises_on_compiled_computation(self): - tensor_type = computation_types.TensorType(np.int32) + tensor_type = federated_language.TensorType(np.int32) comp_proto, comp_type = tensorflow_computation_factory.create_identity( tensor_type ) - comp = building_blocks.CompiledComputation( + comp = federated_language.framework.CompiledComputation( comp_proto, type_signature=comp_type ) with self.assertRaises(TypeError): @@ -95,8 +93,10 @@ def test(): test(), graph ) - function_type = computation_types.FunctionType(None, result_type) - serialized_function_type = type_serialization.serialize_type(function_type) + function_type = federated_language.FunctionType(None, result_type) + serialized_function_type = federated_language.framework.serialize_type( + function_type + ) proto = computation_pb2.Computation( type=serialized_function_type, tensorflow=computation_pb2.TensorFlow( @@ -124,8 +124,10 @@ def test(): test(), graph ) - function_type = computation_types.FunctionType(None, result_type) - serialized_function_type = type_serialization.serialize_type(function_type) + function_type = federated_language.FunctionType(None, result_type) + serialized_function_type = federated_language.framework.serialize_type( + function_type + ) proto = computation_pb2.Computation( type=serialized_function_type, tensorflow=computation_pb2.TensorFlow( @@ -154,8 +156,10 @@ def test(): test(), graph ) - function_type = computation_types.FunctionType(None, result_type) - serialized_function_type = type_serialization.serialize_type(function_type) + function_type = federated_language.FunctionType(None, result_type) + serialized_function_type = federated_language.framework.serialize_type( + function_type + ) proto = computation_pb2.Computation( type=serialized_function_type, tensorflow=computation_pb2.TensorFlow( @@ -180,8 +184,10 @@ def test(): test(), graph ) - function_type = computation_types.FunctionType(None, result_type) - serialized_function_type = type_serialization.serialize_type(function_type) + function_type = federated_language.FunctionType(None, result_type) + serialized_function_type = federated_language.framework.serialize_type( + function_type + ) proto = computation_pb2.Computation( type=serialized_function_type, tensorflow=computation_pb2.TensorFlow( @@ -212,8 +218,10 @@ def test(): test(), graph ) - function_type = computation_types.FunctionType(None, result_type) - serialized_function_type = type_serialization.serialize_type(function_type) + function_type = federated_language.FunctionType(None, result_type) + serialized_function_type = federated_language.framework.serialize_type( + function_type + ) proto = computation_pb2.Computation( type=serialized_function_type, tensorflow=computation_pb2.TensorFlow( @@ -238,8 +246,10 @@ def test(): test(), graph ) - function_type = computation_types.FunctionType(None, result_type) - serialized_function_type = type_serialization.serialize_type(function_type) + function_type = federated_language.FunctionType(None, result_type) + serialized_function_type = federated_language.framework.serialize_type( + function_type + ) proto = computation_pb2.Computation( type=serialized_function_type, tensorflow=computation_pb2.TensorFlow( diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_executor_bindings_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_executor_bindings_test.py index b7c17a610f..af625fa94a 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_executor_bindings_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_executor_bindings_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -25,9 +26,6 @@ from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types from tensorflow_federated.python.core.impl.executors import executor_bindings from tensorflow_federated.python.core.impl.executors import value_serialization -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import type_test_utils # Creating logical devices should be done only once before TF runtime startup @@ -75,7 +73,7 @@ def test_construction( def test_create_value(self): executor = get_executor() # 1. Test a simple tensor. - expected_type_spec = computation_types.TensorType(np.int64, [3]) + expected_type_spec = federated_language.TensorType(np.int64, [3]) value_pb, _ = value_serialization.serialize_value( [1, 2, 3], expected_type_spec ) @@ -88,12 +86,14 @@ def test_create_value(self): deserialized_value, type_spec = value_serialization.deserialize_value( materialized_value ) - type_test_utils.assert_types_identical(type_spec, expected_type_spec) + federated_language.framework.assert_types_identical( + type_spec, expected_type_spec + ) self.assertAllEqual(deserialized_value, [1, 2, 3]) # 2. Test a struct of tensors, ensure that we get a different ID. - expected_type_spec = computation_types.StructType([ - ('a', computation_types.TensorType(np.int64, [3])), - ('b', computation_types.TensorType(np.float32, [])), + expected_type_spec = federated_language.StructType([ + ('a', federated_language.TensorType(np.int64, [3])), + ('b', federated_language.TensorType(np.float32, [])), ]) value = collections.OrderedDict( a=np.array([1, 2, 3], np.int64), @@ -113,7 +113,7 @@ def test_create_value(self): # Note: here we've lost the names `a` and `b` in the output. The output # is a more _strict_ type. self.assertTrue(expected_type_spec.is_assignable_from(type_spec)) - deserialized_value = type_conversions.type_to_py_container( + deserialized_value = federated_language.framework.type_to_py_container( deserialized_value, expected_type_spec ) self.assertAllClose( @@ -123,8 +123,8 @@ def test_create_value(self): # 3. Test creating a value from a computation. foo, _ = tensorflow_computation_factory.create_binary_operator( tf.add, - computation_types.TensorType(np.int64), - computation_types.TensorType(np.int64), + federated_language.TensorType(np.int64), + federated_language.TensorType(np.int64), ) value_pb = executor_pb2.Value(computation=foo) @@ -163,7 +163,7 @@ def test_create_value_sequence_with_reduce_sum( sequence = list(dataset.as_numpy_iterator()) executor = tensorflow_executor_bindings.create_tensorflow_executor() element_type = tensorflow_types.to_type(dataset.element_spec) - sequence_type = computation_types.SequenceType(element_type) + sequence_type = federated_language.SequenceType(element_type) arg_value_pb, _ = value_serialization.serialize_value( sequence, sequence_type ) @@ -184,19 +184,19 @@ def sum_examples(ds): result = executor.create_call(comp.ref, arg.ref) output_pb = executor.materialize(result.ref) result, result_type_spec = value_serialization.deserialize_value(output_pb) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( result_type_spec, - computation_types.TensorType(sequence_type.element.dtype), + federated_language.TensorType(sequence_type.element.dtype), ) self.assertEqual(result, expected_result) def test_create_tuple_of_value_sequence(self): sequences = ([0, 1, 2, 3, 4], [0, 1, 2, 3, 4]) executor = tensorflow_executor_bindings.create_tensorflow_executor() - element_type = computation_types.TensorType(np.int32) - struct_of_sequence_type = computation_types.StructType([ - computation_types.SequenceType(element_type), - computation_types.SequenceType(element_type), + element_type = federated_language.TensorType(np.int32) + struct_of_sequence_type = federated_language.StructType([ + federated_language.SequenceType(element_type), + federated_language.SequenceType(element_type), ]) arg_value_pb, _ = value_serialization.serialize_value( sequences, struct_of_sequence_type @@ -224,13 +224,13 @@ def add_preprocessing(ds1, ds2): _, result_type_spec = value_serialization.deserialize_value( output_pb, type_hint=struct_of_sequence_type ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( result_type_spec, struct_of_sequence_type ) def test_create_struct(self): executor = get_executor() - expected_type_spec = computation_types.TensorType(np.int64, [3]) + expected_type_spec = federated_language.TensorType(np.int64, [3]) value_pb, _ = value_serialization.serialize_value( np.array([1, 2, 3], np.int64), expected_type_spec ) @@ -243,11 +243,13 @@ def test_create_struct(self): deserialized_value, type_spec = value_serialization.deserialize_value( materialized_value ) - struct_type_spec = computation_types.to_type( + struct_type_spec = federated_language.to_type( [expected_type_spec, expected_type_spec] ) - type_test_utils.assert_types_equivalent(type_spec, struct_type_spec) - deserialized_value = type_conversions.type_to_py_container( + federated_language.framework.assert_types_equivalent( + type_spec, struct_type_spec + ) + deserialized_value = federated_language.framework.type_to_py_container( deserialized_value, struct_type_spec ) self.assertAllClose([(1, 2, 3), (1, 2, 3)], deserialized_value) @@ -257,18 +259,20 @@ def test_create_struct(self): deserialized_value, type_spec = value_serialization.deserialize_value( materialized_value ) - struct_type_spec = computation_types.to_type( + struct_type_spec = federated_language.to_type( [struct_type_spec, expected_type_spec] ) - type_test_utils.assert_types_equivalent(type_spec, struct_type_spec) - deserialized_value = type_conversions.type_to_py_container( + federated_language.framework.assert_types_equivalent( + type_spec, struct_type_spec + ) + deserialized_value = federated_language.framework.type_to_py_container( deserialized_value, struct_type_spec ) self.assertAllClose([[(1, 2, 3), (1, 2, 3)], (1, 2, 3)], deserialized_value) def test_create_selection(self): executor = get_executor() - expected_type_spec = computation_types.TensorType(np.int64, [3]) + expected_type_spec = federated_language.TensorType(np.int64, [3]) value_pb, _ = value_serialization.serialize_value( np.array([1, 2, 3], np.int64), expected_type_spec ) @@ -281,11 +285,13 @@ def test_create_selection(self): deserialized_value, type_spec = value_serialization.deserialize_value( materialized_value ) - struct_type_spec = computation_types.to_type( + struct_type_spec = federated_language.to_type( [expected_type_spec, expected_type_spec] ) - type_test_utils.assert_types_equivalent(type_spec, struct_type_spec) - deserialized_value = type_conversions.type_to_py_container( + federated_language.framework.assert_types_equivalent( + type_spec, struct_type_spec + ) + deserialized_value = federated_language.framework.type_to_py_container( deserialized_value, struct_type_spec ) self.assertAllClose([(1, 2, 3), (1, 2, 3)], deserialized_value) @@ -295,8 +301,10 @@ def test_create_selection(self): deserialized_value, type_spec = value_serialization.deserialize_value( materialized_value ) - type_test_utils.assert_types_equivalent(type_spec, expected_type_spec) - deserialized_value = type_conversions.type_to_py_container( + federated_language.framework.assert_types_equivalent( + type_spec, expected_type_spec + ) + deserialized_value = federated_language.framework.type_to_py_container( deserialized_value, struct_type_spec ) self.assertAllClose((1, 2, 3), deserialized_value) @@ -305,15 +313,15 @@ def test_call_with_arg(self): executor = get_executor() value_pb, _ = value_serialization.serialize_value( np.array([1, 2, 3], np.int64), - computation_types.TensorType(np.int64, [3]), + federated_language.TensorType(np.int64, [3]), ) value_ref = executor.create_value(value_pb) arg = executor.create_struct((value_ref.ref, value_ref.ref)) foo, _ = tensorflow_computation_factory.create_binary_operator( tf.add, - computation_types.TensorType(np.int64), - computation_types.TensorType(np.int64), + federated_language.TensorType(np.int64), + federated_language.TensorType(np.int64), ) comp_pb = executor_pb2.Value(computation=foo) @@ -327,7 +335,7 @@ def test_call_no_arg(self): executor = get_executor() foo, _ = tensorflow_computation_factory.create_constant( - 123.0, computation_types.TensorType(np.float32) + 123.0, federated_language.TensorType(np.float32) ) comp_pb = executor_pb2.Value(computation=foo) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations.py index 967cc3fdee..c91253cb77 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations.py @@ -16,6 +16,7 @@ import collections from collections.abc import Callable +import federated_language import numpy as np import tensorflow as tf @@ -23,62 +24,63 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis def reduce_intrinsic( comp, uri, body_fn: Callable[ - [building_blocks.ComputationBuildingBlock], - building_blocks.ComputationBuildingBlock, + [federated_language.framework.ComputationBuildingBlock], + federated_language.framework.ComputationBuildingBlock, ], ): """Replaces all the intrinsics with the given `uri` with a callable.""" - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) py_typecheck.check_type(uri, str) def _should_transform(comp): - return isinstance(comp, building_blocks.Intrinsic) and comp.uri == uri + return ( + isinstance(comp, federated_language.framework.Intrinsic) + and comp.uri == uri + ) def _transform(comp): if not _should_transform(comp): return comp, False - arg_name = next(building_block_factory.unique_name_generator(comp)) - comp_arg = building_blocks.Reference( + arg_name = next(federated_language.framework.unique_name_generator(comp)) + comp_arg = federated_language.framework.Reference( arg_name, comp.type_signature.parameter ) intrinsic_body = body_fn(comp_arg) - intrinsic_reduced = building_blocks.Lambda( + intrinsic_reduced = federated_language.framework.Lambda( comp_arg.name, comp_arg.type_signature, intrinsic_body ) return intrinsic_reduced, True - return transformation_utils.transform_postorder(comp, _transform) + return federated_language.framework.transform_postorder(comp, _transform) def _apply_generic_op(op, arg): if not ( - isinstance(arg.type_signature, computation_types.FederatedType) - or type_analysis.is_structure_of_tensors(arg.type_signature) + isinstance(arg.type_signature, federated_language.FederatedType) + or federated_language.framework.is_structure_of_tensors( + arg.type_signature + ) ): # If there are federated elements nested in a struct, we need to zip these # together before passing to binary operator constructor. - arg = building_block_factory.create_federated_zip(arg) + arg = federated_language.framework.create_federated_zip(arg) return tensorflow_building_block_factory.apply_binary_operator_with_upcast( arg, op ) def _initial_values( - initial_value_fn: Callable[[computation_types.TensorType], object], - member_type: computation_types.Type, -) -> building_blocks.ComputationBuildingBlock: + initial_value_fn: Callable[[federated_language.TensorType], object], + member_type: federated_language.Type, +) -> federated_language.framework.ComputationBuildingBlock: """Create a nested structure of initial values. Args: @@ -88,30 +90,37 @@ def _initial_values( federated type. Returns: - A building_blocks.ComputationBuildingBlock representing the initial values. + A federated_language.framework.ComputationBuildingBlock representing the + initial values. """ - def _fill(tensor_type: computation_types.TensorType) -> building_blocks.Call: + def _fill( + tensor_type: federated_language.TensorType, + ) -> federated_language.framework.Call: computation_proto, function_type = ( tensorflow_computation_factory.create_constant( initial_value_fn(tensor_type), tensor_type ) ) - compiled = building_blocks.CompiledComputation( + compiled = federated_language.framework.CompiledComputation( computation_proto, type_signature=function_type ) - return building_blocks.Call(compiled) + return federated_language.framework.Call(compiled) def _structify_bb( inner_value: object, - ) -> building_blocks.ComputationBuildingBlock: + ) -> federated_language.framework.ComputationBuildingBlock: if isinstance(inner_value, dict): - return building_blocks.Struct( + return federated_language.framework.Struct( [(k, _structify_bb(v)) for k, v in inner_value.items()] ) if isinstance(inner_value, (tuple, list)): - return building_blocks.Struct([_structify_bb(v) for v in inner_value]) - if not isinstance(inner_value, building_blocks.ComputationBuildingBlock): + return federated_language.framework.Struct( + [_structify_bb(v) for v in inner_value] + ) + if not isinstance( + inner_value, federated_language.framework.ComputationBuildingBlock + ): raise ValueError('Encountered unexpected value: ' + str(inner_value)) return inner_value @@ -123,8 +132,8 @@ def _structify_bb( def _get_intrinsic_reductions() -> dict[ str, Callable[ - [building_blocks.ComputationBuildingBlock], - building_blocks.ComputationBuildingBlock, + [federated_language.framework.ComputationBuildingBlock], + federated_language.framework.ComputationBuildingBlock, ], ]: """Returns map from intrinsic to reducing function. @@ -165,41 +174,55 @@ def _get_intrinsic_reductions() -> dict[ def generic_divide(arg): """Divides two arguments when possible.""" - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) return _apply_generic_op(tf.divide, arg) def generic_multiply(arg): """Multiplies two arguments when possible.""" - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) return _apply_generic_op(tf.multiply, arg) def generic_plus(arg): """Adds two arguments when possible.""" - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) return _apply_generic_op(tf.add, arg) def federated_weighted_mean(arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - w = building_blocks.Selection(arg, index=1) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) + w = federated_language.framework.Selection(arg, index=1) multiplied = generic_multiply(arg) - zip_arg = building_blocks.Struct([(None, multiplied), (None, w)]) - summed = federated_sum(building_block_factory.create_federated_zip(zip_arg)) + zip_arg = federated_language.framework.Struct( + [(None, multiplied), (None, w)] + ) + summed = federated_sum( + federated_language.framework.create_federated_zip(zip_arg) + ) return generic_divide(summed) def federated_mean(arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) one = tensorflow_building_block_factory.create_generic_constant( arg.type_signature, 1 ) - mean_arg = building_blocks.Struct([(None, arg), (None, one)]) + mean_arg = federated_language.framework.Struct([(None, arg), (None, one)]) return federated_weighted_mean(mean_arg) - def federated_min(x: building_blocks.ComputationBuildingBlock): - if not isinstance(x.type_signature, computation_types.FederatedType): + def federated_min(x: federated_language.framework.ComputationBuildingBlock): + if not isinstance(x.type_signature, federated_language.FederatedType): raise TypeError('Expected a federated value.') operand_type = x.type_signature.member - def _max_fn(tensor_type: computation_types.TensorType): + def _max_fn(tensor_type: federated_language.TensorType): if np.issubdtype(tensor_type.dtype, np.integer): return np.iinfo(tensor_type.dtype).max elif np.issubdtype(tensor_type.dtype, np.floating): @@ -213,23 +236,23 @@ def _max_fn(tensor_type: computation_types.TensorType): min_proto, min_type = ( tensorflow_computation_factory.create_binary_operator_with_upcast( tf.minimum, - computation_types.StructType([operand_type, operand_type]), + federated_language.StructType([operand_type, operand_type]), ) ) - min_op = building_blocks.CompiledComputation( + min_op = federated_language.framework.CompiledComputation( min_proto, type_signature=min_type ) - identity = building_block_factory.create_identity(operand_type) - return building_block_factory.create_federated_aggregate( + identity = federated_language.framework.create_identity(operand_type) + return federated_language.framework.create_federated_aggregate( x, zero, min_op, min_op, identity ) - def federated_max(x: building_blocks.ComputationBuildingBlock): - if not isinstance(x.type_signature, computation_types.FederatedType): + def federated_max(x: federated_language.framework.ComputationBuildingBlock): + if not isinstance(x.type_signature, federated_language.FederatedType): raise TypeError('Expected a federated value.') operand_type = x.type_signature.member - def _min_fn(tensor_type: computation_types.TensorType): + def _min_fn(tensor_type: federated_language.TensorType): if np.issubdtype(tensor_type.dtype, np.integer): return np.iinfo(tensor_type.dtype).min elif np.issubdtype(tensor_type.dtype, np.floating): @@ -243,33 +266,35 @@ def _min_fn(tensor_type: computation_types.TensorType): max_proto, max_type = ( tensorflow_computation_factory.create_binary_operator_with_upcast( tf.maximum, - computation_types.StructType([operand_type, operand_type]), + federated_language.StructType([operand_type, operand_type]), ) ) - max_op = building_blocks.CompiledComputation( + max_op = federated_language.framework.CompiledComputation( max_proto, type_signature=max_type ) - identity = building_block_factory.create_identity(operand_type) - return building_block_factory.create_federated_aggregate( + identity = federated_language.framework.create_identity(operand_type) + return federated_language.framework.create_federated_aggregate( x, zero, max_op, max_op, identity ) def federated_sum(x): - py_typecheck.check_type(x, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + x, federated_language.framework.ComputationBuildingBlock + ) operand_type = x.type_signature.member # pytype: disable=attribute-error zero = tensorflow_building_block_factory.create_generic_constant( operand_type, 0 ) plus_proto, plus_type = ( tensorflow_computation_factory.create_binary_operator_with_upcast( - tf.add, computation_types.StructType([operand_type, operand_type]) + tf.add, federated_language.StructType([operand_type, operand_type]) ) ) - plus_op = building_blocks.CompiledComputation( + plus_op = federated_language.framework.CompiledComputation( plus_proto, type_signature=plus_type ) - identity = building_block_factory.create_identity(operand_type) - return building_block_factory.create_federated_aggregate( + identity = federated_language.framework.create_identity(operand_type) + return federated_language.framework.create_federated_aggregate( x, zero, plus_op, plus_op, identity ) @@ -331,14 +356,17 @@ def federated_sum(x): # - SEQUENCE_REDUCE intrinsic_bodies_by_uri = collections.OrderedDict([ - (intrinsic_defs.FEDERATED_MEAN.uri, federated_mean), - (intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, federated_weighted_mean), - (intrinsic_defs.FEDERATED_MIN.uri, federated_min), - (intrinsic_defs.FEDERATED_MAX.uri, federated_max), - (intrinsic_defs.FEDERATED_SUM.uri, federated_sum), - (intrinsic_defs.GENERIC_DIVIDE.uri, generic_divide), - (intrinsic_defs.GENERIC_MULTIPLY.uri, generic_multiply), - (intrinsic_defs.GENERIC_PLUS.uri, generic_plus), + (federated_language.framework.FEDERATED_MEAN.uri, federated_mean), + ( + federated_language.framework.FEDERATED_WEIGHTED_MEAN.uri, + federated_weighted_mean, + ), + (federated_language.framework.FEDERATED_MIN.uri, federated_min), + (federated_language.framework.FEDERATED_MAX.uri, federated_max), + (federated_language.framework.FEDERATED_SUM.uri, federated_sum), + (federated_language.framework.GENERIC_DIVIDE.uri, generic_divide), + (federated_language.framework.GENERIC_MULTIPLY.uri, generic_multiply), + (federated_language.framework.GENERIC_PLUS.uri, generic_plus), ]) return intrinsic_bodies_by_uri @@ -347,7 +375,7 @@ def replace_intrinsics_with_bodies(comp): """Iterates over all intrinsic bodies, inlining the intrinsics in `comp`. This function operates on the AST level; meaning, it takes in a - `building_blocks.ComputationBuildingBlock` as an argument and + `federated_language.framework.ComputationBuildingBlock` as an argument and returns one as well. `replace_intrinsics_with_bodies` is intended to be the standard reduction function, which will reduce all currently implemented intrinsics to their bodies. @@ -357,18 +385,20 @@ def replace_intrinsics_with_bodies(comp): function is ordered from more complex intrinsic to less complex intrinsics. Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` in which we - wish to replace all intrinsics with their bodies. + comp: Instance of `federated_language.framework.ComputationBuildingBlock` in + which we wish to replace all intrinsics with their bodies. Returns: - Instance of `building_blocks.ComputationBuildingBlock` with all + Instance of `federated_language.framework.ComputationBuildingBlock` with all the intrinsics from `intrinsic_bodies.py` inlined with their bodies, along with a Boolean indicating whether there was any inlining in fact done. Raises: TypeError: If the types don't match. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) bodies = _get_intrinsic_reductions() transformed = False for uri, body in bodies.items(): diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations_test.py index 718af8afa4..01c1c9bfb1 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_tree_transformations_test.py @@ -14,26 +14,21 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils def _count_intrinsics(comp, uri): def _predicate(comp): return ( - isinstance(comp, building_blocks.Intrinsic) + isinstance(comp, federated_language.framework.Intrinsic) and uri is not None and comp.uri == uri ) - return tree_analysis.count(comp, _predicate) + return federated_language.framework.computation_count(comp, _predicate) class ReplaceIntrinsicsWithBodiesTest(parameterized.TestCase): @@ -43,13 +38,17 @@ def test_raises_on_none(self): tensorflow_tree_transformations.replace_intrinsics_with_bodies(None) def test_federated_mean_reduces_to_aggregate(self): - uri = intrinsic_defs.FEDERATED_MEAN.uri + uri = federated_language.framework.FEDERATED_MEAN.uri - comp = building_blocks.Intrinsic( + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FunctionType( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), ) @@ -59,10 +58,10 @@ def test_federated_mean_reduces_to_aggregate(self): ) count_means_after_reduction = _count_intrinsics(reduced, uri) count_aggregations = _count_intrinsics( - reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_means_before_reduction, 0) @@ -70,14 +69,20 @@ def test_federated_mean_reduces_to_aggregate(self): self.assertGreater(count_aggregations, 0) def test_federated_weighted_mean_reduces_to_aggregate(self): - uri = intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri + uri = federated_language.framework.FEDERATED_WEIGHTED_MEAN.uri - comp = building_blocks.Intrinsic( + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( - (computation_types.FederatedType(np.float32, placements.CLIENTS),) + federated_language.FunctionType( + ( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + ) * 2, - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), ) @@ -86,11 +91,11 @@ def test_federated_weighted_mean_reduces_to_aggregate(self): tensorflow_tree_transformations.replace_intrinsics_with_bodies(comp) ) count_aggregations = _count_intrinsics( - reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri ) count_means_after_reduction = _count_intrinsics(reduced, uri) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_means_before_reduction, 0) @@ -98,13 +103,17 @@ def test_federated_weighted_mean_reduces_to_aggregate(self): self.assertGreater(count_aggregations, 0) def test_federated_min_reduces_to_aggregate(self): - uri = intrinsic_defs.FEDERATED_MIN.uri + uri = federated_language.framework.FEDERATED_MIN.uri - comp = building_blocks.Intrinsic( + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FunctionType( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), ) @@ -114,10 +123,10 @@ def test_federated_min_reduces_to_aggregate(self): ) count_min_after_reduction = _count_intrinsics(reduced, uri) count_aggregations = _count_intrinsics( - reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_min_before_reduction, 0) @@ -125,13 +134,17 @@ def test_federated_min_reduces_to_aggregate(self): self.assertGreater(count_aggregations, 0) def test_federated_max_reduces_to_aggregate(self): - uri = intrinsic_defs.FEDERATED_MAX.uri + uri = federated_language.framework.FEDERATED_MAX.uri - comp = building_blocks.Intrinsic( + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FunctionType( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), ) @@ -141,10 +154,10 @@ def test_federated_max_reduces_to_aggregate(self): ) count_max_after_reduction = _count_intrinsics(reduced, uri) count_aggregations = _count_intrinsics( - reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_max_before_reduction, 0) @@ -152,13 +165,17 @@ def test_federated_max_reduces_to_aggregate(self): self.assertGreater(count_aggregations, 0) def test_federated_sum_reduces_to_aggregate(self): - uri = intrinsic_defs.FEDERATED_SUM.uri + uri = federated_language.framework.FEDERATED_SUM.uri - comp = building_blocks.Intrinsic( + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType( - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FunctionType( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), ) @@ -168,10 +185,10 @@ def test_federated_sum_reduces_to_aggregate(self): ) count_sum_after_reduction = _count_intrinsics(reduced, uri) count_aggregations = _count_intrinsics( - reduced, intrinsic_defs.FEDERATED_AGGREGATE.uri + reduced, federated_language.framework.FEDERATED_AGGREGATE.uri ) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_sum_before_reduction, 0) @@ -179,10 +196,10 @@ def test_federated_sum_reduces_to_aggregate(self): self.assertGreater(count_aggregations, 0) def test_generic_divide_reduces(self): - uri = intrinsic_defs.GENERIC_DIVIDE.uri - comp = building_blocks.Intrinsic( + uri = federated_language.framework.GENERIC_DIVIDE.uri + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType([np.float32, np.float32], np.float32), + federated_language.FunctionType([np.float32, np.float32], np.float32), ) count_before_reduction = _count_intrinsics(comp, uri) @@ -192,18 +209,20 @@ def test_generic_divide_reduces(self): count_after_reduction = _count_intrinsics(reduced, uri) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_before_reduction, 0) self.assertEqual(count_after_reduction, 0) - tree_analysis.check_contains_only_reducible_intrinsics(reduced) + federated_language.framework.check_contains_only_reducible_intrinsics( + reduced + ) def test_generic_multiply_reduces(self): - uri = intrinsic_defs.GENERIC_MULTIPLY.uri - comp = building_blocks.Intrinsic( + uri = federated_language.framework.GENERIC_MULTIPLY.uri + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType([np.float32, np.float32], np.float32), + federated_language.FunctionType([np.float32, np.float32], np.float32), ) count_before_reduction = _count_intrinsics(comp, uri) @@ -213,18 +232,20 @@ def test_generic_multiply_reduces(self): count_after_reduction = _count_intrinsics(reduced, uri) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_before_reduction, 0) self.assertEqual(count_after_reduction, 0) - tree_analysis.check_contains_only_reducible_intrinsics(reduced) + federated_language.framework.check_contains_only_reducible_intrinsics( + reduced + ) def test_generic_plus_reduces(self): - uri = intrinsic_defs.GENERIC_PLUS.uri - comp = building_blocks.Intrinsic( + uri = federated_language.framework.GENERIC_PLUS.uri + comp = federated_language.framework.Intrinsic( uri, - computation_types.FunctionType([np.float32, np.float32], np.float32), + federated_language.FunctionType([np.float32, np.float32], np.float32), ) count_before_reduction = _count_intrinsics(comp, uri) @@ -234,12 +255,14 @@ def test_generic_plus_reduces(self): count_after_reduction = _count_intrinsics(reduced, uri) self.assertTrue(modified) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, reduced.type_signature ) self.assertGreater(count_before_reduction, 0) self.assertEqual(count_after_reduction, 0) - tree_analysis.check_contains_only_reducible_intrinsics(reduced) + federated_language.framework.check_contains_only_reducible_intrinsics( + reduced + ) if __name__ == '__main__': diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils.py index 0b6fe998f7..17ffcb8e9b 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils.py @@ -20,19 +20,17 @@ from typing import Optional import attrs +import federated_language +from federated_language.proto import computation_pb2 as pb import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import graph_utils from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_serialization _TENSOR_REPRESENTATION_TYPES = ( # Python native types @@ -62,7 +60,7 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph): a best-effort attempt will be made to make them similar for ease of debugging. parameter_type: The type of the parameter to stamp. Must be either an - instance of computation_types.Type (or convertible to it), or None. + instance of federated_language.Type (or convertible to it), or None. graph: The instance of tf.Graph to stamp in. Returns: @@ -74,7 +72,7 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph): to the tensors and ops stamped into the graph. Raises: - TypeError: If the arguments are of the wrong computation_types. + TypeError: If the arguments are of the wrong federated_language. ValueError: If the parameter type cannot be stamped in a TensorFlow graph. """ py_typecheck.check_type(parameter_name, str) @@ -82,7 +80,7 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph): if parameter_type is None: return (None, None) parameter_type = tensorflow_types.to_type(parameter_type) - if isinstance(parameter_type, computation_types.TensorType): + if isinstance(parameter_type, federated_language.TensorType): with graph.as_default(): placeholder = tf.compat.v1.placeholder( dtype=parameter_type.dtype, @@ -93,10 +91,10 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph): tensor=pb.TensorFlow.TensorBinding(tensor_name=placeholder.name) ) return (placeholder, binding) - elif isinstance(parameter_type, computation_types.StructType): + elif isinstance(parameter_type, federated_language.StructType): # The parameter_type could be a StructTypeWithPyContainer, however, we # ignore that for now. Instead, the proper containers will be inserted at - # call time by function_utils.wrap_as_zero_or_one_arg_callable. + # call time by federated_language.framework.wrap_as_zero_or_one_arg_callable. if not parameter_type: # Stamps whimsy element to "populate" graph, as TensorFlow does not # support empty graphs. @@ -116,7 +114,7 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph): struct=pb.TensorFlow.StructBinding(element=element_bindings) ), ) - elif isinstance(parameter_type, computation_types.SequenceType): + elif isinstance(parameter_type, federated_language.SequenceType): with graph.as_default(): with tf.device('/device:cpu:0'): variant_tensor = tf.compat.v1.placeholder(tf.variant, shape=[]) @@ -194,7 +192,7 @@ class UnsupportedGraphResultError(InvalidGraphResultError): def capture_result_from_graph( result: object, graph: tf.Graph, -) -> tuple[computation_types.Type, pb.TensorFlow.Binding]: +) -> tuple[federated_language.Type, pb.TensorFlow.Binding]: """Captures a result stamped into a tf.Graph as a type signature and binding. Args: @@ -205,7 +203,7 @@ def capture_result_from_graph( Returns: A tuple (type_spec, binding), where 'type_spec' is an instance of - computation_types.Type that describes the type of the result, and 'binding' + federated_language.Type that describes the type of the result, and 'binding' is an instance of TensorFlow.Binding that indicates how parts of the result type relate to the tensors and ops that appear in the result. @@ -222,7 +220,7 @@ def _get_bindings_for_elements( name_value_pairs: Iterable[tuple[str, object]], graph: tf.Graph, container_type: Optional[type[object]], - ) -> tuple[computation_types.Type, pb.TensorFlow.Binding]: + ) -> tuple[federated_language.Type, pb.TensorFlow.Binding]: """Build `(type_spec, binding)` tuple for name value pairs.""" element_name_type_binding_triples = [ ((k,) + capture_result_from_graph(v, graph)) @@ -233,11 +231,11 @@ def _get_bindings_for_elements( type_member = (e[0], e[1]) if e[0] else e[1] type_members.append(type_member) if container_type: - type_spec = computation_types.StructWithPythonType( + type_spec = federated_language.StructWithPythonType( type_members, container_type=container_type ) else: - type_spec = computation_types.StructType(type_members) + type_spec = federated_language.StructType(type_members) binding = pb.TensorFlow.Binding( struct=pb.TensorFlow.StructBinding( element=[e[2] for e in element_name_type_binding_triples] @@ -291,7 +289,7 @@ def _get_bindings_for_elements( else: shape = None return ( - computation_types.TensorType(dtype, shape), + federated_language.TensorType(dtype, shape), pb.TensorFlow.Binding( tensor=pb.TensorFlow.TensorBinding(tensor_name=result.name) ), @@ -325,7 +323,7 @@ def _get_bindings_for_elements( capture_result_from_graph(e, graph) for e in result ] return ( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [e[0] for e in element_type_binding_pairs], type(result) ), pb.TensorFlow.Binding( @@ -354,7 +352,7 @@ def _get_bindings_for_elements( with tf.device('/device:cpu:0'): variant_tensor = tf.identity(tf.data.experimental.to_variant(result)) return ( - computation_types.SequenceType(element_type), + federated_language.SequenceType(element_type), pb.TensorFlow.Binding( sequence=pb.TensorFlow.SequenceBinding( variant_tensor_name=variant_tensor.name @@ -383,7 +381,7 @@ def _compute_map_from_bindings(source, target): tensors in the corresponding parts of `target`. Raises: - TypeError: If the arguments are of the wrong computation_types. + TypeError: If the arguments are of the wrong federated_language. ValueError: If the bindings have mismatching structures. """ py_typecheck.check_type(source, pb.TensorFlow.Binding) @@ -492,7 +490,7 @@ def _assemble_result_from_graph(type_spec, binding, output_map): the type and binding don't match, or the tensor is not found in the map. """ type_spec = tensorflow_types.to_type(type_spec) - py_typecheck.check_type(type_spec, computation_types.Type) + py_typecheck.check_type(type_spec, federated_language.Type) py_typecheck.check_type(binding, pb.TensorFlow.Binding) py_typecheck.check_type(output_map, dict) for k, v in output_map.items(): @@ -505,7 +503,7 @@ def _assemble_result_from_graph(type_spec, binding, output_map): ) binding_oneof = binding.WhichOneof('binding') - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): if binding_oneof != 'tensor': raise ValueError( 'Expected a tensor binding, found {}.'.format(binding_oneof) @@ -528,7 +526,7 @@ def _assemble_result_from_graph(type_spec, binding, output_map): 'Prefer usage of `tf.ensure_shape` to `tf.set_shape`.' ) return tensor - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): if binding_oneof != 'struct': raise ValueError( 'Expected a struct binding, found {}.'.format(binding_oneof) @@ -557,7 +555,7 @@ def _assemble_result_from_graph(type_spec, binding, output_map): ) or attrs.has(container_type): return container_type(**dict(result_elements)) return container_type(result_elements) # pylint: disable=too-many-function-args - elif isinstance(type_spec, computation_types.SequenceType): + elif isinstance(type_spec, federated_language.SequenceType): if binding_oneof != 'sequence': raise ValueError( 'Expected a sequence binding, found {}.'.format(binding_oneof) @@ -598,10 +596,10 @@ def _make_empty_list_structure_for_element_type_spec(type_spec): TypeError: If the `type_spec` is not of a form described above. """ type_spec = tensorflow_types.to_type(type_spec) - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): + py_typecheck.check_type(type_spec, federated_language.Type) + if isinstance(type_spec, federated_language.TensorType): return [] - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = structure.to_elements(type_spec) if all(k is not None for k, _ in elements): return collections.OrderedDict([ @@ -645,27 +643,27 @@ def _make_whimsy_element_for_type_spec(type_spec, none_dim_replacement=0): Returns: Returns possibly nested `numpy ndarray`s containing all zeros: a single - `ndarray` if `type_spec` is a `computation_types.TensorType` and a list - of such arrays if `type_spec` is `computation_types.StructType`. + `ndarray` if `type_spec` is a `federated_language.TensorType` and a list + of such arrays if `type_spec` is `federated_language.StructType`. This data structure is of the minimal size necessary in order to be compatible with `type_spec`. """ type_spec = tensorflow_types.to_type(type_spec) - def _predicate(type_spec: computation_types.Type) -> bool: + def _predicate(type_spec: federated_language.Type) -> bool: return isinstance( type_spec, ( - computation_types.TensorType, - computation_types.StructType, + federated_language.TensorType, + federated_language.StructType, ), ) - if not type_analysis.contains_only(type_spec, _predicate): + if not federated_language.framework.type_contains_only(type_spec, _predicate): raise ValueError( 'Cannot construct array for TFF type containing anything ' - 'other than `computation_types.TensorType` or ' - '`computation_types.StructType`; you have passed the ' + 'other than `federated_language.TensorType` or ' + '`federated_language.StructType`; you have passed the ' 'type {}'.format(type_spec) ) py_typecheck.check_type(none_dim_replacement, int) @@ -679,12 +677,12 @@ def _handle_none_dimension(x): return none_dim_replacement return x - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): whimsy_shape = [_handle_none_dimension(x) for x in type_spec.shape] if type_spec.dtype == np.str_: return np.empty(whimsy_shape, dtype=np.str_) return np.zeros(whimsy_shape, type_spec.dtype) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = structure.to_elements(type_spec) elem_list = [] for _, elem_type in elements: @@ -728,14 +726,14 @@ def _append_to_list_structure_for_element_type_spec(nested, value, type_spec): 'Expected an anonymous tuple to either have all elements named or ' 'all unnamed, got {}.'.format(value) ) - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): if not isinstance(nested, list): raise TypeError(f'Expected `nested` to be a `list`, found {type(nested)}') # Convert the members to tensors to ensure that they are properly # typed and grouped before being passed to # tf.data.Dataset.from_tensor_slices. nested.append(tf.convert_to_tensor(value, type_spec.dtype)) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = structure.to_elements(type_spec) if isinstance(nested, collections.OrderedDict): if isinstance(value, py_typecheck.SupportsNamedTuple): @@ -807,14 +805,14 @@ def _replace_empty_leaf_lists_with_numpy_arrays(lists, type_spec): `lists` is not of a type compatible with `type_spec`. """ type_spec = tensorflow_types.to_type(type_spec) - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): + py_typecheck.check_type(type_spec, federated_language.Type) + if isinstance(type_spec, federated_language.TensorType): py_typecheck.check_type(lists, list) if lists: return lists else: return np.array([], dtype=type_spec.dtype) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = structure.to_elements(type_spec) if isinstance(lists, collections.OrderedDict): to_return = [] @@ -880,7 +878,7 @@ def make_data_set_from_elements(graph, elements, element_type): raise ValueError('Only in eager context may the graph be `None`.') py_typecheck.check_type(elements, list) element_type = tensorflow_types.to_type(element_type) - py_typecheck.check_type(element_type, computation_types.Type) + py_typecheck.check_type(element_type, federated_language.Type) def _make(element_subset): lists = _make_empty_list_structure_for_element_type_spec(element_type) @@ -1015,7 +1013,7 @@ def _interleave_dataset_results_and_tensors(dataset_results, flat_run_tensors): def coerce_dataset_elements_to_tff_type_spec( - dataset: tf.data.Dataset, element_type: computation_types.Type + dataset: tf.data.Dataset, element_type: federated_language.Type ) -> tf.data.Dataset: """Map the elements of a dataset to a specified type. @@ -1035,10 +1033,10 @@ def coerce_dataset_elements_to_tff_type_spec( ValueError: if the elements of `dataset` cannot be coerced into `element_type`. """ - py_typecheck.check_type(element_type, computation_types.Type) - if isinstance(element_type, computation_types.TensorType): + py_typecheck.check_type(element_type, federated_language.Type) + if isinstance(element_type, federated_language.TensorType): return dataset - elif isinstance(element_type, computation_types.StructWithPythonType): + elif isinstance(element_type, federated_language.StructWithPythonType): py_type = element_type.python_container if py_type is tf.RaggedTensor or py_type is tf.sparse.SparseTensor: return dataset @@ -1047,9 +1045,9 @@ def coerce_dataset_elements_to_tff_type_spec( # look for opportunities to consolidate? def _to_representative_value(type_spec, elements): """Convert to a container to a type understood by TF and TFF.""" - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): return elements - elif isinstance(type_spec, computation_types.StructWithPythonType): + elif isinstance(type_spec, federated_language.StructWithPythonType): if tf.is_tensor(elements): # In this case we have a singleton tuple tensor that may have been # unwrapped by tf.data. @@ -1073,7 +1071,7 @@ def _to_representative_value(type_spec, elements): if isinstance(py_type, py_typecheck.SupportsNamedTuple): return py_type(*values) return py_type(values) # pylint: disable=too-many-function-args - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): field_types = structure.to_elements(type_spec) is_all_named = all([name is not None for name, _ in field_types]) if is_all_named: @@ -1184,7 +1182,9 @@ def deserialize_and_call_tf_computation( ) py_typecheck.check_type(graph, tf.Graph) with graph.as_default(): - type_spec = type_serialization.deserialize_type(computation_proto.type) + type_spec = federated_language.framework.deserialize_type( + computation_proto.type + ) if type_spec.parameter is None: # pytype: disable=attribute-error if arg is None: input_map = {} diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils_test.py index 8e07641383..77117e9af3 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils_test.py @@ -16,18 +16,16 @@ from absl.testing import absltest import attrs +import federated_language +from federated_language.proto import computation_pb2 as pb import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import type_test_utils class GraphUtilsTest(tf.test.TestCase): @@ -37,7 +35,7 @@ def _assert_binding_matches_type_and_value( ): """Asserts that 'bindings' matches the given type, value, and graph.""" self.assertIsInstance(binding, pb.TensorFlow.Binding) - self.assertIsInstance(type_spec, computation_types.Type) + self.assertIsInstance(type_spec, federated_language.Type) binding_oneof = binding.WhichOneof('binding') if binding_oneof == 'tensor': self.assertTrue(tf.is_tensor(val)) @@ -49,7 +47,7 @@ def _assert_binding_matches_type_and_value( else: # Input binding names are expected to match self.assertEqual(binding.tensor.tensor_name, val.name) - self.assertIsInstance(type_spec, computation_types.TensorType) + self.assertIsInstance(type_spec, federated_language.TensorType) self.assertEqual(type_spec.dtype, val.dtype.base_dtype) self.assertEqual(type_spec.shape, val.shape) elif binding_oneof == 'sequence': @@ -62,13 +60,13 @@ def _assert_binding_matches_type_and_value( op = str(variant_tensor.op.type) self.assertTrue((op == 'Placeholder') or (op == 'Identity')) self.assertEqual(variant_tensor.dtype, tf.variant) - self.assertIsInstance(type_spec, computation_types.SequenceType) + self.assertIsInstance(type_spec, federated_language.SequenceType) self.assertEqual( tensorflow_types.to_type(val.element_spec), type_spec.element, ) elif binding_oneof == 'struct': - self.assertIsInstance(type_spec, computation_types.StructType) + self.assertIsInstance(type_spec, federated_language.StructType) if not isinstance(val, (list, tuple, structure.Struct)): self.assertIsInstance(val, dict) val = list(val.values()) @@ -94,7 +92,7 @@ def _assert_output_binding_matches_type_and_value( ) def _assert_captured_result_eq_dtype(self, type_spec, binding, dtype): - self.assertIsInstance(type_spec, computation_types.TensorType) + self.assertIsInstance(type_spec, federated_language.TensorType) self.assertEqual(str(type_spec), dtype) self.assertEqual(binding.WhichOneof('binding'), 'tensor') @@ -153,7 +151,7 @@ def test_stamp_parameter_in_graph_with_struct(self): with tf.Graph().as_default() as my_graph: x = self._checked_stamp_parameter( 'foo', - computation_types.StructType([ + federated_language.StructType([ ('a', np.int32), ('b', np.bool_), ]), @@ -167,7 +165,7 @@ def test_stamp_parameter_in_graph_with_struct_with_python_type(self): with tf.Graph().as_default() as my_graph: x = self._checked_stamp_parameter( 'foo', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.int32), ('b', np.bool_), @@ -183,7 +181,7 @@ def test_stamp_parameter_in_graph_with_struct_with_python_type(self): def test_stamp_parameter_in_graph_with_bool_sequence(self): with tf.Graph().as_default(): x = self._checked_stamp_parameter( - 'foo', computation_types.SequenceType(np.bool_) + 'foo', federated_language.SequenceType(np.bool_) ) self.assertIsInstance(x, tf.data.Dataset) self.assertEqual(x.element_spec, tf.TensorSpec(shape=(), dtype=tf.bool)) @@ -191,7 +189,7 @@ def test_stamp_parameter_in_graph_with_bool_sequence(self): def test_stamp_parameter_in_graph_with_int_vector_sequence(self): with tf.Graph().as_default(): x = self._checked_stamp_parameter( - 'foo', computation_types.SequenceType((np.int32, [50])) + 'foo', federated_language.SequenceType((np.int32, [50])) ) self.assertIsInstance(x, tf.data.Dataset) self.assertEqual( @@ -202,7 +200,7 @@ def test_stamp_parameter_in_graph_with_tensor_ordered_dict_sequence(self): with tf.Graph().as_default(): x = self._checked_stamp_parameter( 'foo', - computation_types.SequenceType( + federated_language.SequenceType( collections.OrderedDict( [('A', (np.float32, [3, 4, 5])), ('B', (np.int32, [1]))] ) @@ -291,15 +289,18 @@ def test_capture_result_with_ragged_tensor(self): tf.RaggedTensor.from_row_splits([0, 0, 0, 0], [0, 1, 4]), graph ) del binding - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( type_spec, - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('flat_values', computation_types.TensorType(np.int32, [4])), + ('flat_values', federated_language.TensorType(np.int32, [4])), ( 'nested_row_splits', - computation_types.StructWithPythonType( - [(None, computation_types.TensorType(np.int64, [3]))], + federated_language.StructWithPythonType( + [( + None, + federated_language.TensorType(np.int64, [3]), + )], tuple, ), ), @@ -314,13 +315,13 @@ def test_capture_result_with_sparse_tensor(self): tf.SparseTensor(indices=[[1]], values=[2], dense_shape=[5]), graph ) del binding - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( type_spec, - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('indices', computation_types.TensorType(np.int64, [1, 1])), - ('values', computation_types.TensorType(np.int32, [1])), - ('dense_shape', computation_types.TensorType(np.int64, [1])), + ('indices', federated_language.TensorType(np.int64, [1, 1])), + ('values', federated_language.TensorType(np.int32, [1])), + ('dense_shape', federated_language.TensorType(np.int64, [1])), ], tf.SparseTensor, ), @@ -435,8 +436,8 @@ def test_capture_result_with_struct_of_constants(self): ]) ) self.assertEqual(str(t), '') - self.assertIsInstance(t, computation_types.StructType) - self.assertNotIsInstance(t, computation_types.StructWithPythonType) + self.assertIsInstance(t, federated_language.StructType) + self.assertNotIsInstance(t, federated_language.StructWithPythonType) @tensorflow_test_utils.graph_mode_test def test_capture_result_with_nested_lists_and_tuples(self): @@ -453,11 +454,11 @@ def test_capture_result_with_nested_lists_and_tuples(self): ]) ) self.assertEqual(str(t), '>,b=>,<>>') - self.assertIsInstance(t, computation_types.StructType) - self.assertNotIsInstance(t, computation_types.StructWithPythonType) - self.assertIsInstance(t.x, computation_types.StructWithPythonType) + self.assertIsInstance(t, federated_language.StructType) + self.assertNotIsInstance(t, federated_language.StructWithPythonType) + self.assertIsInstance(t.x, federated_language.StructWithPythonType) self.assertIs(t.x.python_container, named_tuple_type) - self.assertIsInstance(t[1], computation_types.StructWithPythonType) + self.assertIsInstance(t[1], federated_language.StructWithPythonType) self.assertIs(t[1].python_container, list) @tensorflow_test_utils.graph_mode_test @@ -551,7 +552,7 @@ def test_assemble_result_from_graph_with_named_tuple(self): @tensorflow_test_utils.graph_mode_test def test_assemble_result_from_graph_with_sequence_of_odicts(self): - type_spec = computation_types.SequenceType( + type_spec = federated_language.SequenceType( collections.OrderedDict([('X', np.int32), ('Y', np.int32)]) ) binding = pb.TensorFlow.Binding( @@ -576,7 +577,7 @@ def test_assemble_result_from_graph_with_sequence_of_odicts(self): @tensorflow_test_utils.graph_mode_test def test_assemble_result_from_graph_with_sequence_of_namedtuples(self): named_tuple_type = collections.namedtuple('TestNamedTuple', 'X Y') - type_spec = computation_types.SequenceType( + type_spec = federated_language.SequenceType( named_tuple_type(np.int32, np.int32) ) binding = pb.TensorFlow.Binding( @@ -599,7 +600,7 @@ def test_assemble_result_from_graph_with_sequence_of_namedtuples(self): ) def test__make_whimsy_element_for_type_spec_raises_sequence_type(self): - type_spec = computation_types.SequenceType(np.float32) + type_spec = federated_language.SequenceType(np.float32) with self.assertRaisesRegex( ValueError, 'Cannot construct array for TFF type' ): @@ -612,7 +613,7 @@ def test__make_whimsy_element_for_type_spec_raises_negative_none_dim_replacement tensorflow_utils._make_whimsy_element_for_type_spec(tf.float32, -1) def test_make_whimsy_element_tensor_type(self): - type_spec = computation_types.TensorType( + type_spec = federated_language.TensorType( np.float32, [None, 10, None, 10, 10] ) elem = tensorflow_utils._make_whimsy_element_for_type_spec(type_spec) @@ -620,7 +621,7 @@ def test_make_whimsy_element_tensor_type(self): self.assertAllClose(elem, correct_elem) def test_make_whimsy_element_tensor_type_none_replaced_by_1(self): - type_spec = computation_types.TensorType( + type_spec = federated_language.TensorType( np.float32, [None, 10, None, 10, 10] ) elem = tensorflow_utils._make_whimsy_element_for_type_spec( @@ -630,10 +631,12 @@ def test_make_whimsy_element_tensor_type_none_replaced_by_1(self): self.assertAllClose(elem, correct_elem) def test_make_whimsy_element_struct_type(self): - tensor1 = computation_types.TensorType(np.float32, [None, 10, None, 10, 10]) - tensor2 = computation_types.TensorType(np.int32, [10, None, 10]) - namedtuple = computation_types.StructType([('x', tensor1), ('y', tensor2)]) - unnamedtuple = computation_types.StructType( + tensor1 = federated_language.TensorType( + np.float32, [None, 10, None, 10, 10] + ) + tensor2 = federated_language.TensorType(np.int32, [10, None, 10]) + namedtuple = federated_language.StructType([('x', tensor1), ('y', tensor2)]) + unnamedtuple = federated_language.StructType( [('x', tensor1), ('y', tensor2)] ) elem = tensorflow_utils._make_whimsy_element_for_type_spec(namedtuple) @@ -698,7 +701,7 @@ def test_make_data_set_from_elements_with_empty_list_definite_tensor(self): ds = tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [], - computation_types.TensorType(np.float32, [None, 10]), + federated_language.TensorType(np.float32, [None, 10]), ) self.assertIsInstance(ds, tf.data.Dataset) self.assertEqual( @@ -714,8 +717,8 @@ def test_make_data_set_from_elements_with_empty_list_definite_tuple(self): tf.compat.v1.get_default_graph(), [], [ - computation_types.TensorType(np.float32, [None, 10]), - computation_types.TensorType(np.float32, [None, 5]), + federated_language.TensorType(np.float32, [None, 10]), + federated_language.TensorType(np.float32, [None, 5]), ], ) self.assertIsInstance(ds, tf.data.Dataset) @@ -1066,7 +1069,7 @@ def test_list_structures_from_element_type_spec_with_int_value(self): def test_list_structures_from_element_type_spec_with_empty_dict_value(self): self._test_list_structure( - computation_types.StructType([]), [{}], 'OrderedDict()' + federated_language.StructType([]), [{}], 'OrderedDict()' ) def test_list_structures_from_element_type_spec_with_dict_value(self): @@ -1100,12 +1103,12 @@ def test_list_structures_from_element_type_spec_with_int_values(self): def test_list_structures_from_element_type_spec_with_empty_dict_values(self): self._test_list_structure( - computation_types.StructType([]), [{}, {}, {}], 'OrderedDict()' + federated_language.StructType([]), [{}, {}, {}], 'OrderedDict()' ) def test_list_structures_from_element_type_spec_with_structures(self): self._test_list_structure( - computation_types.StructType([('a', np.int32)]), + federated_language.StructType([('a', np.int32)]), [structure.Struct([('a', 1)]), structure.Struct([('a', 2)])], ( "OrderedDict([('a', [" @@ -1116,15 +1119,15 @@ def test_list_structures_from_element_type_spec_with_structures(self): def test_list_structures_from_element_type_spec_with_empty_anon_tuples(self): self._test_list_structure( - computation_types.StructType([]), + federated_language.StructType([]), [structure.Struct([]), structure.Struct([])], 'OrderedDict()', ) def test_list_structures_from_element_type_spec_w_list_of_anon_tuples(self): self._test_list_structure( - computation_types.StructType( - [computation_types.StructType([('a', np.int32)])] + federated_language.StructType( + [federated_language.StructType([('a', np.int32)])] ), [[structure.Struct([('a', 1)])], [structure.Struct([('a', 2)])]], ( @@ -1144,12 +1147,12 @@ def test_make_data_set_from_elements_with_odd_last_batch(self): tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [np.array([1, 2]), np.array([3])], - computation_types.TensorType(np.int32, (None,)), + federated_language.TensorType(np.int32, (None,)), ) tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [{'x': np.array([1, 2])}, {'x': np.array([3])}], - [('x', computation_types.TensorType(np.int32, (None,)))], + [('x', federated_language.TensorType(np.int32, (None,)))], ) def test_make_data_set_from_elements_with_odd_all_batches(self): @@ -1161,7 +1164,7 @@ def test_make_data_set_from_elements_with_odd_all_batches(self): np.array([4, 5, 6]), np.array([7, 8]), ], - computation_types.TensorType(np.int32, (None,)), + federated_language.TensorType(np.int32, (None,)), ) tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), @@ -1171,19 +1174,19 @@ def test_make_data_set_from_elements_with_odd_all_batches(self): {'x': np.array([4, 5, 6])}, {'x': np.array([7, 8])}, ], - [('x', computation_types.TensorType(np.int32, (None,)))], + [('x', federated_language.TensorType(np.int32, (None,)))], ) def test_make_data_set_from_elements_with_just_one_batch(self): tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [np.array([1])], - computation_types.TensorType(np.int32, (None,)), + federated_language.TensorType(np.int32, (None,)), ) tensorflow_utils.make_data_set_from_elements( tf.compat.v1.get_default_graph(), [{'x': np.array([1])}], - [('x', computation_types.TensorType(np.int32, (None,)))], + [('x', federated_language.TensorType(np.int32, (None,)))], ) def test_make_dataset_from_variant_tensor_constructs_dataset(self): @@ -1213,7 +1216,7 @@ def test_make_dataset_from_variant_tensor_fails_with_bad_type(self): def test_coerce_dataset_elements_noop(self): x = tf.data.Dataset.range(5) y = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec( - x, computation_types.TensorType(np.int64) + x, federated_language.TensorType(np.int64) ) self.assertEqual(x.element_spec, y.element_spec) @@ -1222,13 +1225,13 @@ def test_coerce_ragged_tensor_dataset_elements_noop(self): values=[3, 1, 4], row_splits=[0, 2, 2, 3] ) dataset = tf.data.Dataset.from_tensors(ragged_tensor) - element_type = computation_types.StructWithPythonType( + element_type = federated_language.StructWithPythonType( [ - ('flat_values', computation_types.TensorType(np.int32)), + ('flat_values', federated_language.TensorType(np.int32)), ( 'nested_row_splits', - computation_types.StructWithPythonType( - [computation_types.TensorType(np.int64, [None])], tuple + federated_language.StructWithPythonType( + [federated_language.TensorType(np.int64, [None])], tuple ), ), ], @@ -1244,14 +1247,14 @@ def test_coerce_sparse_tensor_dataset_elements_noop(self): indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4] ) dataset = tf.data.Dataset.from_tensors(sparse_tensor) - element_type = computation_types.StructWithPythonType( + element_type = federated_language.StructWithPythonType( [ ( 'indices', - computation_types.TensorType(np.int64, shape=[None, 2]), + federated_language.TensorType(np.int64, shape=[None, 2]), ), - ('values', computation_types.TensorType(np.int32, shape=[None])), - ('dense_shape', computation_types.TensorType(np.int64, shape=[2])), + ('values', federated_language.TensorType(np.int32, shape=[None])), + ('dense_shape', federated_language.TensorType(np.int64, shape=[2])), ], tf.sparse.SparseTensor, ) @@ -1276,15 +1279,15 @@ def _make_nested_tf_structure(x): x = tf.data.Dataset.range(5).map(_make_nested_tf_structure) - element_type = computation_types.StructType([ + element_type = federated_language.StructType([ ( 'a', - computation_types.StructType([ + federated_language.StructType([ (None, np.int64), (None, test_tuple_type(np.int64, np.int64)), ( None, - computation_types.StructType( + federated_language.StructType( [('x', np.int64), ('y', np.int64)] ), ), @@ -1313,14 +1316,18 @@ def test_deserialize_and_call_tf_computation_with_add_one(self): result_type, result_binding = tensorflow_utils.capture_result_from_graph( result, graph ) - parameter_type = computation_types.TensorType(np.int32) - type_signature = computation_types.FunctionType(parameter_type, result_type) + parameter_type = federated_language.TensorType(np.int32) + type_signature = federated_language.FunctionType( + parameter_type, result_type + ) tensorflow_proto = pb.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, result=result_binding, ) - serialized_type = type_serialization.serialize_type(type_signature) + serialized_type = federated_language.framework.serialize_type( + type_signature + ) computation_proto = pb.Computation( type=serialized_type, tensorflow=tensorflow_proto ) @@ -1356,17 +1363,21 @@ def test_deserialize_and_call_tf_computation_with_placeholder_replacement( result_type, result_binding = tensorflow_utils.capture_result_from_graph( result, graph ) - parameter_type = computation_types.StructType([ - (None, computation_types.TensorType(np.int32)), - (None, computation_types.TensorType(np.int32)), + parameter_type = federated_language.StructType([ + (None, federated_language.TensorType(np.int32)), + (None, federated_language.TensorType(np.int32)), ]) - type_signature = computation_types.FunctionType(parameter_type, result_type) + type_signature = federated_language.FunctionType( + parameter_type, result_type + ) tensorflow_proto = pb.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, result=result_binding, ) - serialized_type = type_serialization.serialize_type(type_signature) + serialized_type = federated_language.framework.serialize_type( + type_signature + ) computation_proto = pb.Computation( type=serialized_type, tensorflow=tensorflow_proto ) @@ -1393,14 +1404,18 @@ def test_deserialize_and_call_tf_computation_returning_session_token(self): result, graph ) parameter_type = None - type_signature = computation_types.FunctionType(parameter_type, result_type) + type_signature = federated_language.FunctionType( + parameter_type, result_type + ) tensorflow_proto = pb.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=None, session_token_tensor_name=session_token_placeholder.name, result=result_binding, ) - serialized_type = type_serialization.serialize_type(type_signature) + serialized_type = federated_language.framework.serialize_type( + type_signature + ) computation_proto = pb.Computation( type=serialized_type, tensorflow=tensorflow_proto ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions.py b/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions.py index 99427db3a7..38e0aaf076 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions.py @@ -17,34 +17,32 @@ from typing import Optional import attrs +import federated_language import tensorflow as tf import tree from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import typed_object -def _tensor_to_type(tensor: tf.Tensor) -> computation_types.Type: +def _tensor_to_type(tensor: tf.Tensor) -> federated_language.Type: """Returns a `tff.Type` for the `tensor`.""" return tensorflow_types.to_type((tensor.dtype, tensor.shape)) -def _variable_to_type(variable: tf.Variable) -> computation_types.Type: +def _variable_to_type(variable: tf.Variable) -> federated_language.Type: """Returns a `tff.Type` for the `variable`.""" return tensorflow_types.to_type((variable.dtype, variable.shape)) -def _dataset_to_type(dataset: tf.data.Dataset) -> computation_types.Type: +def _dataset_to_type(dataset: tf.data.Dataset) -> federated_language.Type: """Returns a `tff.Type` for the `dataset`.""" dataset_spec = tf.data.DatasetSpec.from_value(dataset) return tensorflow_types.to_type(dataset_spec) -def tensorflow_infer_type(obj: object) -> Optional[computation_types.Type]: +def tensorflow_infer_type(obj: object) -> Optional[federated_language.Type]: """Returns a `tff.Type` for an `obj` containing TensorFlow values. This function extends `type_conversions.infer_type` to handle TensorFlow @@ -74,13 +72,13 @@ def tensorflow_infer_type(obj: object) -> Optional[computation_types.Type]: obj: An object to infer a `tff.Type`. """ - class _Placeholder(typed_object.TypedObject): + class _Placeholder(federated_language.TypedObject): - def __init__(self, type_signature: computation_types.Type): + def __init__(self, type_signature: federated_language.Type): self._type_signature = type_signature @property - def type_signature(self) -> computation_types.Type: + def type_signature(self) -> federated_language.Type: return self._type_signature def _infer_type(obj): @@ -101,10 +99,10 @@ def _infer_type(obj): return None partial = tree.traverse(_infer_type, obj) - return type_conversions.infer_type(partial) + return federated_language.framework.infer_type(partial) -def _type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type): +def _type_to_tf_dtypes_and_shapes(type_spec: federated_language.Type): """Returns nested structures of tensor dtypes and shapes for a given TFF type. The returned dtypes and shapes match those used by `tf.data.Dataset`s to @@ -112,7 +110,7 @@ def _type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type): arguments in constructing an iterator over a string handle. Args: - type_spec: A `computation_types.Type`, the type specification must be + type_spec: A `federated_language.Type`, the type specification must be composed of only named tuples and tensors. In all named tuples that appear in the type spec, all the elements must be named. @@ -126,11 +124,11 @@ def _type_to_tf_dtypes_and_shapes(type_spec: computation_types.Type): ValueError: if the `type_spec` is composed of something other than named tuples and tensors, or if any of the elements in named tuples are unnamed. """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): + py_typecheck.check_type(type_spec, federated_language.Type) + if isinstance(type_spec, federated_language.TensorType): shape = tf.TensorShape(type_spec.shape) return (type_spec.dtype, shape) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = structure.to_elements(type_spec) if not elements: output_dtypes = [] @@ -194,7 +192,7 @@ def build_py_container(elements): ) -def type_to_tf_tensor_specs(type_spec: computation_types.Type): +def type_to_tf_tensor_specs(type_spec: federated_language.Type): """Returns nested structure of `tf.TensorSpec`s for a given TFF type. The dtypes and shapes of the returned `tf.TensorSpec`s match those used by @@ -202,7 +200,7 @@ def type_to_tf_tensor_specs(type_spec: computation_types.Type): be used, e.g., as arguments in constructing an iterator over a string handle. Args: - type_spec: A `computation_types.Type`, the type specification must be + type_spec: A `federated_language.Type`, the type specification must be composed of only named tuples and tensors. In all named tuples that appear in the type spec, all the elements must be named. @@ -218,11 +216,11 @@ def type_to_tf_tensor_specs(type_spec: computation_types.Type): ) -def type_to_tf_structure(type_spec: computation_types.Type): +def type_to_tf_structure(type_spec: federated_language.Type): """Returns nested `tf.data.experimental.Structure` for a given TFF type. Args: - type_spec: A `computation_types.Type`, the type specification must be + type_spec: A `federated_language.Type`, the type specification must be composed of only named tuples and tensors. In all named tuples that appear in the type spec, all the elements must be named. @@ -234,10 +232,10 @@ def type_to_tf_structure(type_spec: computation_types.Type): ValueError: if the `type_spec` is composed of something other than named tuples and tensors, or if any of the elements in named tuples are unnamed. """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): + py_typecheck.check_type(type_spec, federated_language.Type) + if isinstance(type_spec, federated_language.TensorType): return tf.TensorSpec(type_spec.shape, type_spec.dtype) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): elements = structure.to_elements(type_spec) if not elements: return () @@ -269,10 +267,10 @@ def type_to_tf_structure(type_spec: computation_types.Type): def _structure_from_tensor_type_tree_inner( - fn, type_spec: computation_types.Type + fn, type_spec: federated_language.Type ): """Helper for `structure_from_tensor_type_tree`.""" - if isinstance(type_spec, computation_types.StructType): + if isinstance(type_spec, federated_language.StructType): def _map_element(element): name, nested_type = element return (name, _structure_from_tensor_type_tree_inner(fn, nested_type)) @@ -280,7 +278,7 @@ def _map_element(element): return structure.Struct( map(_map_element, structure.iter_elements(type_spec)) ) - elif isinstance(type_spec, computation_types.TensorType): + elif isinstance(type_spec, federated_language.TensorType): return fn(type_spec) else: raise ValueError( @@ -290,7 +288,7 @@ def _map_element(element): def structure_from_tensor_type_tree( - fn: Callable[[computation_types.TensorType], object], type_spec + fn: Callable[[federated_language.TensorType], object], type_spec ) -> object: """Constructs a structure from a `type_spec` tree of `tff.TensorType`s. @@ -309,6 +307,8 @@ def structure_from_tensor_type_tree( Raises: ValueError: if the provided `type_spec` is not a structural or tensor type. """ - type_spec = computation_types.to_type(type_spec) + type_spec = federated_language.to_type(type_spec) non_python_typed = _structure_from_tensor_type_tree_inner(fn, type_spec) - return type_conversions.type_to_py_container(non_python_typed, type_spec) + return federated_language.framework.type_to_py_container( + non_python_typed, type_spec + ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions_test.py b/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions_test.py index 5a19147c66..3cfd1aa51e 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions_test.py @@ -17,23 +17,20 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions as tensorflow_type_conversions -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import type_test_utils -from tensorflow_federated.python.core.impl.types import typed_object -class _TestTypedObject(typed_object.TypedObject): +class _TestTypedObject(federated_language.TypedObject): - def __init__(self, type_signature: computation_types.Type): + def __init__(self, type_signature: federated_language.Type): self._type_signature = type_signature @property - def type_signature(self) -> computation_types.Type: + def type_signature(self) -> federated_language.Type: return self._type_signature @@ -43,14 +40,14 @@ class TensorflowInferTypeTest(parameterized.TestCase): ( 'tensor', tf.ones(shape=[2, 3], dtype=tf.int32), - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ), ( 'tensor_nested', [tf.ones(shape=[2, 3], dtype=tf.int32)], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ], list, ), @@ -58,10 +55,10 @@ class TensorflowInferTypeTest(parameterized.TestCase): ( 'tensor_mixed', [tf.ones(shape=[2, 3], dtype=tf.int32), 1.0], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.float32), ], list, ), @@ -69,14 +66,14 @@ class TensorflowInferTypeTest(parameterized.TestCase): ( 'variable', tf.Variable(tf.ones(shape=[2, 3], dtype=tf.int32)), - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ), ( 'variable_nested', [tf.Variable(tf.ones(shape=[2, 3], dtype=tf.int32))], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ], list, ), @@ -84,10 +81,10 @@ class TensorflowInferTypeTest(parameterized.TestCase): ( 'variable_mixed', [tf.Variable(tf.ones(shape=[2, 3], dtype=tf.int32)), 1.0], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.float32), ], list, ), @@ -95,17 +92,17 @@ class TensorflowInferTypeTest(parameterized.TestCase): ( 'dataset', tf.data.Dataset.from_tensors(tf.ones(shape=[2, 3], dtype=tf.int32)), - computation_types.SequenceType( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.SequenceType( + federated_language.TensorType(np.int32, shape=[2, 3]) ), ), ( 'dataset_nested', [tf.data.Dataset.from_tensors(tf.ones(shape=[2, 3], dtype=tf.int32))], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.SequenceType( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.SequenceType( + federated_language.TensorType(np.int32, shape=[2, 3]) ), ], list, @@ -119,12 +116,12 @@ class TensorflowInferTypeTest(parameterized.TestCase): ), 1.0, ], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.SequenceType( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.SequenceType( + federated_language.TensorType(np.int32, shape=[2, 3]) ), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), ], list, ), @@ -139,7 +136,7 @@ def test_returns_result_with_tensorflow_obj(self, obj, expected_result): ( 'typed_object', _TestTypedObject( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.TensorType(np.int32, shape=[2, 3]) ), ), ('int', 1), @@ -151,7 +148,7 @@ def test_returns_result_with_tensorflow_obj(self, obj, expected_result): def test_delegates_result_with_obj(self, obj): with mock.patch.object( - type_conversions, 'infer_type', autospec=True, spec_set=True + federated_language.framework, 'infer_type', autospec=True, spec_set=True ) as mock_infer_type: tensorflow_type_conversions.tensorflow_infer_type(obj) mock_infer_type.assert_called_once_with(obj) @@ -160,7 +157,7 @@ def test_delegates_result_with_obj(self, obj): class TypeToTfDtypesAndShapesTest(absltest.TestCase): def test_with_int_scalar(self): - type_signature = computation_types.TensorType(np.int32) + type_signature = federated_language.TensorType(np.int32) dtypes, shapes = tensorflow_type_conversions._type_to_tf_dtypes_and_shapes( type_signature ) @@ -168,7 +165,7 @@ def test_with_int_scalar(self): self.assertEqual(shapes, ()) def test_with_int_vector(self): - type_signature = computation_types.TensorType(np.int32, [10]) + type_signature = federated_language.TensorType(np.int32, [10]) dtypes, shapes = tensorflow_type_conversions._type_to_tf_dtypes_and_shapes( type_signature ) @@ -176,11 +173,11 @@ def test_with_int_vector(self): self.assertEqual(shapes, (10,)) def test_with_tensor_triple(self): - type_signature = computation_types.StructWithPythonType( + type_signature = federated_language.StructWithPythonType( [ - ('a', computation_types.TensorType(np.int32, [5])), - ('b', computation_types.TensorType(np.bool_)), - ('c', computation_types.TensorType(np.float32, [3])), + ('a', federated_language.TensorType(np.int32, [5])), + ('b', federated_language.TensorType(np.bool_)), + ('c', federated_language.TensorType(np.float32, [3])), ], collections.OrderedDict, ) @@ -191,20 +188,20 @@ def test_with_tensor_triple(self): self.assertEqual(shapes, {'a': [5], 'b': [], 'c': [3]}) def test_with_two_level_tuple(self): - type_signature = computation_types.StructWithPythonType( + type_signature = federated_language.StructWithPythonType( [ ('a', np.bool_), ( 'b', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('c', computation_types.TensorType(np.float32)), - ('d', computation_types.TensorType(np.int32, [20])), + ('c', federated_language.TensorType(np.float32)), + ('d', federated_language.TensorType(np.int32, [20])), ], collections.OrderedDict, ), ), - ('e', computation_types.StructType([])), + ('e', federated_language.StructType([])), ], collections.OrderedDict, ) @@ -220,25 +217,25 @@ def test_with_two_level_tuple(self): class TypeToTfTensorSpecsTest(absltest.TestCase): def test_with_int_scalar(self): - type_signature = computation_types.TensorType(np.int32) + type_signature = federated_language.TensorType(np.int32) tensor_specs = tensorflow_type_conversions.type_to_tf_tensor_specs( type_signature ) self.assertEqual(tensor_specs, tf.TensorSpec([], np.int32)) def test_with_int_vector(self): - type_signature = computation_types.TensorType(np.int32, [10]) + type_signature = federated_language.TensorType(np.int32, [10]) tensor_specs = tensorflow_type_conversions.type_to_tf_tensor_specs( type_signature ) self.assertEqual(tensor_specs, tf.TensorSpec([10], np.int32)) def test_with_tensor_triple(self): - type_signature = computation_types.StructWithPythonType( + type_signature = federated_language.StructWithPythonType( [ - ('a', computation_types.TensorType(np.int32, [5])), - ('b', computation_types.TensorType(np.bool_)), - ('c', computation_types.TensorType(np.float32, [3])), + ('a', federated_language.TensorType(np.int32, [5])), + ('b', federated_language.TensorType(np.bool_)), + ('c', federated_language.TensorType(np.float32, [3])), ], collections.OrderedDict, ) @@ -255,20 +252,20 @@ def test_with_tensor_triple(self): ) def test_with_two_level_tuple(self): - type_signature = computation_types.StructWithPythonType( + type_signature = federated_language.StructWithPythonType( [ ('a', np.bool_), ( 'b', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('c', computation_types.TensorType(np.float32)), - ('d', computation_types.TensorType(np.int32, [20])), + ('c', federated_language.TensorType(np.float32)), + ('d', federated_language.TensorType(np.int32, [20])), ], collections.OrderedDict, ), ), - ('e', computation_types.StructType([])), + ('e', federated_language.StructType([])), ], collections.OrderedDict, ) @@ -292,7 +289,7 @@ def test_with_invalid_type(self): tensorflow_type_conversions.type_to_tf_tensor_specs(np.float32(0.0)) def test_with_unnamed_element(self): - type_signature = computation_types.StructType([np.int32]) + type_signature = federated_language.StructType([np.int32]) tensor_specs = tensorflow_type_conversions.type_to_tf_tensor_specs( type_signature ) @@ -312,15 +309,15 @@ def test_with_names(self): ]), ), ]) - type_spec = computation_types.StructWithPythonType( + type_spec = federated_language.StructWithPythonType( [ - ('a', computation_types.TensorType(np.bool_)), + ('a', federated_language.TensorType(np.bool_)), ( 'b', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('c', computation_types.TensorType(np.float32)), - ('d', computation_types.TensorType(np.int32, (20,))), + ('c', federated_language.TensorType(np.float32)), + ('d', federated_language.TensorType(np.int32, (20,))), ], collections.OrderedDict, ), @@ -341,7 +338,7 @@ def test_without_names(self): tf.TensorSpec(shape=(), dtype=np.bool_), tf.TensorSpec(shape=(), dtype=np.int32), ) - type_spec = computation_types.StructType([np.bool_, np.int32]) + type_spec = federated_language.StructType([np.bool_, np.int32]) tf_structure = tensorflow_type_conversions.type_to_tf_structure(type_spec) with tf.Graph().as_default(): ds = tf.data.experimental.from_variant( @@ -357,18 +354,18 @@ def test_with_none(self): def test_with_sequence_type(self): with self.assertRaises(ValueError): tensorflow_type_conversions.type_to_tf_structure( - computation_types.SequenceType(np.int32) + federated_language.SequenceType(np.int32) ) def test_with_inconsistently_named_elements(self): with self.assertRaises(ValueError): tensorflow_type_conversions.type_to_tf_structure( - computation_types.StructType([('a', np.int32), np.bool_]) + federated_language.StructType([('a', np.int32), np.bool_]) ) def test_with_no_elements(self): tf_structure = tensorflow_type_conversions.type_to_tf_structure( - computation_types.StructType([]) + federated_language.StructType([]) ) self.assertEqual(tf_structure, ()) @@ -388,8 +385,8 @@ def fn(ignored): def test_single_tensor(self): def expect_tfint32_return_5(tensor_type): - type_test_utils.assert_types_identical( - tensor_type, computation_types.TensorType(np.int32) + federated_language.framework.assert_types_identical( + tensor_type, federated_language.TensorType(np.int32) ) return 5 @@ -399,7 +396,7 @@ def expect_tfint32_return_5(tensor_type): self.assertEqual(result, 5) def test_dict(self): - struct_type = computation_types.StructWithPythonType( + struct_type = federated_language.StructWithPythonType( [('a', np.int32), ('b', np.int32)], collections.OrderedDict ) return_incr = self.get_incrementing_function() diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD b/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD index 7ffc1fc722..311a33a0a8 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD @@ -30,11 +30,7 @@ py_library( ":tensorflow_serialization", ":tensorflow_types", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:type_analysis", + "@federated_language//federated_language", ], ) @@ -44,13 +40,7 @@ py_test( srcs = ["tensorflow_computation_test.py"], deps = [ ":tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/context_stack:get_context_stack", - "//tensorflow_federated/python/core/impl/context_stack:runtime_error_context", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -62,11 +52,7 @@ py_library( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_utils", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", + "@federated_language//federated_language", ], ) @@ -77,11 +63,8 @@ py_test( deps = [ ":tensorflow_computation", ":tensorflow_computation_context", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -91,14 +74,11 @@ py_library( deps = [ ":tensorflow_computation_context", ":variable_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:serialization_utils", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_utils", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -110,9 +90,7 @@ py_test( ":tensorflow_serialization", "//tensorflow_federated/python/core/environments/tensorflow_backend:serialization_utils", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", ], ) @@ -124,10 +102,7 @@ py_library( py_library( name = "tensorflow_types", srcs = ["tensorflow_types.py"], - deps = [ - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], + deps = ["@federated_language//federated_language"], ) py_test( @@ -143,6 +118,6 @@ py_test( srcs = ["tensorflow_types_test.py"], deps = [ ":tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation.py index 745fa393a8..5d916f0071 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation.py @@ -15,17 +15,13 @@ from typing import Optional +import federated_language import tensorflow as tf import tree from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_serialization from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import type_analysis def _to_numpy(value: object) -> object: @@ -62,7 +58,9 @@ def _tf_wrapper_fn( ): """Wrapper function to plug Tensorflow logic into the TFF framework.""" del name # Unused. - if not type_analysis.is_tensorflow_compatible_type(parameter_type): + if not federated_language.framework.is_tensorflow_compatible_type( + parameter_type + ): raise TypeError( '`tff.tensorflow.computation`s can accept only parameter types with ' 'constituents `SequenceType`, `StructType` ' @@ -70,23 +68,23 @@ def _tf_wrapper_fn( 'with the type {}.'.format(parameter_type) ) - fn = function_utils.wrap_as_zero_or_one_arg_callable( + fn = federated_language.framework.wrap_as_zero_or_one_arg_callable( fn, parameter_type, unpack ) - context_stack = context_stack_impl.context_stack + context_stack = federated_language.framework.global_context_stack comp_pb, extra_type_spec = ( tensorflow_serialization.serialize_py_fn_as_tf_computation( fn, parameter_type, context_stack ) ) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=comp_pb, context_stack=context_stack, annotated_type=extra_type_spec, ) -tf_computation = computation_wrapper.ComputationWrapper( +tf_computation = federated_language.framework.ComputationWrapper( _tf_wrapper_fn, tensorflow_types.to_type, type_conversions.tensorflow_infer_type, diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context.py index 57fa247b40..5735217141 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context.py @@ -13,6 +13,7 @@ # limitations under the License. """The implementation of a context to use in building TF computations.""" +import federated_language import tensorflow as tf import tree @@ -20,18 +21,13 @@ from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions as tensorflow_type_conversions -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions def get_session_token() -> tf.Tensor: """Returns a string tensor identifying the current session.""" - context = context_stack_impl.context_stack.current + context = federated_language.framework.global_context_stack.current if not isinstance(context, TensorFlowComputationContext): - raise context_base.ContextError( + raise federated_language.framework.ContextError( 'Session tokens can only be retrieved from within the ' '`TensorFlowComputationContext (in a `@tff.tensorflow.computation`). ' f'Instead, the context {context} of type {type(context)} was found.' @@ -39,7 +35,7 @@ def get_session_token() -> tf.Tensor: return context.session_token -class TensorFlowComputationContext(context_base.SyncContext): +class TensorFlowComputationContext(federated_language.framework.SyncContext): """The context for building TensorFlow computations.""" def __init__(self, graph, session_token): @@ -59,7 +55,7 @@ def session_token(self): """Returns a string tensor which uniquely identifies the current session.""" return self._session_token - def invoke(self, comp: computation_impl.ConcreteComputation, arg): + def invoke(self, comp: federated_language.framework.ConcreteComputation, arg): if comp.type_signature.parameter is not None: # Normalize to a Python structure to make it simpler to handle; `args` is # sometimes a `tff.structure.Struct` and sometimes it is not, other times @@ -77,18 +73,22 @@ def _to_python(obj): if not comp.type_signature.parameter.is_assignable_from(inferred_type): raise TypeError( - computation_types.type_mismatch_error_message( + federated_language.framework.type_mismatch_error_message( inferred_type, comp.type_signature.parameter, - computation_types.TypeRelation.ASSIGNABLE, + federated_language.framework.TypeRelation.ASSIGNABLE, second_is_expected=True, ) ) # We are invoking a `tff.tensorflow.computation` inside of another # `tff.tensorflow.computation`. - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) - computation_proto = computation_impl.ConcreteComputation.get_proto(comp) + py_typecheck.check_type( + comp, federated_language.framework.ConcreteComputation + ) + computation_proto = ( + federated_language.framework.ConcreteComputation.get_proto(comp) + ) computation_oneof = computation_proto.WhichOneof('computation') if computation_oneof != 'tensorflow': raise ValueError( @@ -106,6 +106,6 @@ def _to_python(obj): ) if init_op: self._init_ops.append(init_op) - return type_conversions.type_to_py_container( + return federated_language.framework.type_to_py_container( result, comp.type_signature.result ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context_test.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context_test.py index b0313d532a..f3520e4d4b 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_context_test.py @@ -13,32 +13,28 @@ # limitations under the License. from absl.testing import absltest +import federated_language +from federated_language.proto import computation_pb2 as pb import numpy as np import tensorflow as tf - -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation_context -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization class TensorFlowComputationContextTest(absltest.TestCase): def test_invoke_raises_value_error_with_federated_computation(self): bogus_proto = pb.Computation( - type=type_serialization.serialize_type( - computation_types.to_type( - computation_types.FunctionType(np.int32, np.int32) + type=federated_language.framework.serialize_type( + federated_language.to_type( + federated_language.FunctionType(np.int32, np.int32) ) ), reference=pb.Reference(name='boogledy'), ) - non_tf_computation = computation_impl.ConcreteComputation( + non_tf_computation = federated_language.framework.ConcreteComputation( computation_proto=bogus_proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) context = tensorflow_computation_context.TensorFlowComputationContext( diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py index c50eb71302..1df463120b 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py @@ -17,18 +17,12 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import ml_dtypes import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.context_stack import get_context_stack -from tensorflow_federated.python.core.impl.context_stack import runtime_error_context -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils def one_arg_fn(x): @@ -68,33 +62,35 @@ def test_returns_expected_result(self, value, expected_result): class TensorFlowComputationTest(parameterized.TestCase): @parameterized.named_parameters( - ('tensor_bool', computation_types.TensorType(np.bool_)), - ('tensor_int8', computation_types.TensorType(np.int8)), - ('tensor_int16', computation_types.TensorType(np.int16)), - ('tensor_int32', computation_types.TensorType(np.int32)), - ('tensor_int64', computation_types.TensorType(np.int64)), - ('tensor_uint8', computation_types.TensorType(np.uint8)), - ('tensor_uint16', computation_types.TensorType(np.uint16)), - ('tensor_uint32', computation_types.TensorType(np.uint32)), - ('tensor_uint64', computation_types.TensorType(np.uint64)), - ('tensor_float16', computation_types.TensorType(np.float16)), - ('tensor_float32', computation_types.TensorType(np.float32)), - ('tensor_float64', computation_types.TensorType(np.float64)), - ('tensor_complex64', computation_types.TensorType(np.complex64)), - ('tensor_complex128', computation_types.TensorType(np.complex128)), - ('tensor_bfloat16', computation_types.TensorType(ml_dtypes.bfloat16)), - ('tensor_str', computation_types.TensorType(np.str_)), - ('tensor_generic', computation_types.TensorType(np.int32)), - ('tensor_array', computation_types.TensorType(np.int32, shape=[3])), - ('sequence', computation_types.SequenceType(np.int32)), + ('tensor_bool', federated_language.TensorType(np.bool_)), + ('tensor_int8', federated_language.TensorType(np.int8)), + ('tensor_int16', federated_language.TensorType(np.int16)), + ('tensor_int32', federated_language.TensorType(np.int32)), + ('tensor_int64', federated_language.TensorType(np.int64)), + ('tensor_uint8', federated_language.TensorType(np.uint8)), + ('tensor_uint16', federated_language.TensorType(np.uint16)), + ('tensor_uint32', federated_language.TensorType(np.uint32)), + ('tensor_uint64', federated_language.TensorType(np.uint64)), + ('tensor_float16', federated_language.TensorType(np.float16)), + ('tensor_float32', federated_language.TensorType(np.float32)), + ('tensor_float64', federated_language.TensorType(np.float64)), + ('tensor_complex64', federated_language.TensorType(np.complex64)), + ('tensor_complex128', federated_language.TensorType(np.complex128)), + ('tensor_bfloat16', federated_language.TensorType(ml_dtypes.bfloat16)), + ('tensor_str', federated_language.TensorType(np.str_)), + ('tensor_generic', federated_language.TensorType(np.int32)), + ('tensor_array', federated_language.TensorType(np.int32, shape=[3])), + ('sequence', federated_language.SequenceType(np.int32)), ) def test_returns_concrete_computation_with_dtype(self, type_spec): @tensorflow_computation.tf_computation(type_spec) def _comp(x): return x - self.assertIsInstance(_comp, computation_impl.ConcreteComputation) - expected_type = computation_types.FunctionType(type_spec, type_spec) + self.assertIsInstance( + _comp, federated_language.framework.ConcreteComputation + ) + expected_type = federated_language.FunctionType(type_spec, type_spec) self.assertEqual(_comp.type_signature, expected_type) @parameterized.named_parameters( @@ -126,13 +122,13 @@ def test_tf_computation_with_type( def test_tf_computation_without_type(self, fn): fn = tensorflow_computation.tf_computation(fn) concrete_fn = fn.fn_for_argument_type( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ) self.assertEqual( concrete_fn.type_signature.compact_representation(), '(int32 -> bool)' ) concrete_fn = fn.fn_for_argument_type( - computation_types.TensorType(np.float32) + federated_language.TensorType(np.float32) ) self.assertEqual( concrete_fn.type_signature.compact_representation(), '(float32 -> bool)' @@ -153,13 +149,13 @@ def foo(x): return x > 10 concrete_fn = foo.fn_for_argument_type( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ) self.assertEqual( concrete_fn.type_signature.compact_representation(), '(int32 -> bool)' ) concrete_fn = foo.fn_for_argument_type( - computation_types.TensorType(np.float32) + federated_language.TensorType(np.float32) ) self.assertEqual( concrete_fn.type_signature.compact_representation(), '(float32 -> bool)' @@ -189,9 +185,9 @@ def foo(t): foo = tensorflow_computation.tf_computation(foo) concrete_fn = foo.fn_for_argument_type( - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), + federated_language.StructType([ + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), ]) ) self.assertEqual( @@ -199,9 +195,9 @@ def foo(t): '( -> int32)', ) concrete_fn = foo.fn_for_argument_type( - computation_types.StructType([ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.float32), + federated_language.StructType([ + federated_language.TensorType(np.float32), + federated_language.TensorType(np.float32), ]) ) self.assertEqual( @@ -258,7 +254,7 @@ def foo(x, t, l, odict, my_type): foo = tensorflow_computation.tf_computation(foo) concrete_fn = foo.fn_for_argument_type( - computation_types.to_type([ + federated_language.to_type([ np.int32, (np.int32, np.int32), [np.int32, np.int32], @@ -274,7 +270,7 @@ def foo(x, t, l, odict, my_type): ), ) concrete_fn = foo.fn_for_argument_type( - computation_types.to_type([ + federated_language.to_type([ np.float32, (np.float32, np.float32), [np.float32, np.float32], @@ -343,7 +339,7 @@ def foo(t): foo = tensorflow_computation.tf_computation(foo) concrete_fn = foo.fn_for_argument_type( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [('x', np.int32), ('y', np.int32)], MyType ) ) @@ -352,7 +348,7 @@ def foo(t): '( -> int32)', ) concrete_fn = foo.fn_for_argument_type( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [('x', np.float32), ('y', np.float32)], MyType ) ) @@ -378,11 +374,13 @@ def foo(x): ) def test_fails_with_bad_types(self): - function = computation_types.FunctionType( - None, computation_types.TensorType(np.int32) + function = federated_language.FunctionType( + None, federated_language.TensorType(np.int32) ) - federated = computation_types.FederatedType(np.int32, placements.CLIENTS) - tuple_on_function = computation_types.StructType([federated, function]) + federated = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + tuple_on_function = federated_language.StructType([federated, function]) def foo(x): del x # Unused. @@ -403,14 +401,14 @@ def foo(x): TypeError, r'you have attempted to create one with the type placement' ): tensorflow_computation.tf_computation( - foo, computation_types.PlacementType() + foo, federated_language.PlacementType() ) with self.assertRaisesRegex( TypeError, r'you have attempted to create one with the type T' ): tensorflow_computation.tf_computation( - foo, computation_types.AbstractType('T') + foo, federated_language.AbstractType('T') ) with self.assertRaisesRegex( @@ -435,19 +433,21 @@ def test_error_on_non_callable_non_type(self): tensorflow_computation.tf_computation(5) def test_stack_resets_on_none_returned(self): - stack = get_context_stack.get_context_stack() + stack = federated_language.framework.get_context_stack() self.assertIsInstance( - stack.current, runtime_error_context.RuntimeErrorContext + stack.current, federated_language.framework.RuntimeErrorContext ) - with self.assertRaises(computation_wrapper.ComputationReturnedNoneError): + with self.assertRaises( + federated_language.framework.ComputationReturnedNoneError + ): @tensorflow_computation.tf_computation() def _(): pass self.assertIsInstance( - stack.current, runtime_error_context.RuntimeErrorContext + stack.current, federated_language.framework.RuntimeErrorContext ) def test_custom_numpy_dtype(self): @@ -456,11 +456,11 @@ def test_custom_numpy_dtype(self): def foo(x): return x - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( foo.type_signature, - computation_types.FunctionType( - parameter=computation_types.TensorType(tf.bfloat16.as_numpy_dtype), - result=computation_types.TensorType(tf.bfloat16.as_numpy_dtype), + federated_language.FunctionType( + parameter=federated_language.TensorType(tf.bfloat16.as_numpy_dtype), + result=federated_language.TensorType(tf.bfloat16.as_numpy_dtype), ), ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization.py index 9234ef4df3..d0590c4dc6 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization.py @@ -16,23 +16,20 @@ import inspect from typing import Optional +import federated_language +from federated_language.proto import computation_pb2 as pb import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation_context from tensorflow_federated.python.core.environments.tensorflow_frontend import variable_utils -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization def serialize_py_fn_as_tf_computation( fn, - parameter_type: Optional[computation_types.Type], + parameter_type: Optional[federated_language.Type], context_stack, ): """Serializes a TF computation with a given parameter type. @@ -50,7 +47,7 @@ def serialize_py_fn_as_tf_computation( referenced from here). parameter_type: The parameter type specification if the fn accepts a parameter, or `None` if the fn doesn't declare any parameters. Either an - instance of `computation_types.Type`. + instance of `federated_language.Type`. context_stack: The context stack to use. Returns: @@ -69,9 +66,11 @@ def serialize_py_fn_as_tf_computation( # Document all accepted forms with examples in the API, and point to there # from here. - py_typecheck.check_type(context_stack, context_stack_base.ContextStack) + py_typecheck.check_type( + context_stack, federated_language.framework.ContextStack + ) if parameter_type is not None: - py_typecheck.check_type(parameter_type, computation_types.Type) + py_typecheck.check_type(parameter_type, federated_language.Type) signature = inspect.signature(fn) with tf.Graph().as_default() as graph: @@ -109,7 +108,7 @@ def serialize_py_fn_as_tf_computation( else: result = fn() if result is None: - raise computation_wrapper.ComputationReturnedNoneError(fn) + raise federated_language.framework.ComputationReturnedNoneError(fn) initializer_ops = [] if all_variables: # Use a readable but not-too-long name for the init_op. @@ -144,7 +143,7 @@ def serialize_py_fn_as_tf_computation( result, graph ) - type_signature = computation_types.FunctionType(parameter_type, result_type) + type_signature = federated_language.FunctionType(parameter_type, result_type) tensorflow = pb.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), @@ -155,7 +154,7 @@ def serialize_py_fn_as_tf_computation( ) return ( pb.Computation( - type=type_serialization.serialize_type(type_signature), + type=federated_language.framework.serialize_type(type_signature), tensorflow=tensorflow, ), type_signature, diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization_test.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization_test.py index 2c895b321e..bf26791fb5 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_serialization_test.py @@ -15,15 +15,13 @@ import collections from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_serialization -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization class TensorFlowSerializationTest(tf.test.TestCase): @@ -33,11 +31,13 @@ def test_serialize_tensorflow_with_no_parameter(self): tensorflow_serialization.serialize_py_fn_as_tf_computation( lambda: tf.constant(99), None, - context_stack_impl.context_stack, + federated_language.framework.global_context_stack, ) ) self.assertEqual( - type_serialization.deserialize_type(comp.type).compact_representation(), + federated_language.framework.deserialize_type( + comp.type + ).compact_representation(), '( -> int32)', ) self.assertEqual(comp.WhichOneof('computation'), 'tensorflow') @@ -64,12 +64,14 @@ def table_lookup(word): comp, extra_type_spec = ( tensorflow_serialization.serialize_py_fn_as_tf_computation( table_lookup, - computation_types.TensorType(np.str_, (None,)), - context_stack_impl.context_stack, + federated_language.TensorType(np.str_, (None,)), + federated_language.framework.global_context_stack, ) ) self.assertEqual( - type_serialization.deserialize_type(comp.type).compact_representation(), + federated_language.framework.deserialize_type( + comp.type + ).compact_representation(), '(str[?] -> int64[?])', ) self.assertEqual(comp.WhichOneof('computation'), 'tensorflow') @@ -97,12 +99,14 @@ def test_serialize_tensorflow_with_simple_add_three_lambda(self): comp, extra_type_spec = ( tensorflow_serialization.serialize_py_fn_as_tf_computation( lambda x: x + 3, - computation_types.TensorType(np.int32), - context_stack_impl.context_stack, + federated_language.TensorType(np.int32), + federated_language.framework.global_context_stack, ) ) self.assertEqual( - type_serialization.deserialize_type(comp.type).compact_representation(), + federated_language.framework.deserialize_type( + comp.type + ).compact_representation(), '(int32 -> int32)', ) self.assertEqual(comp.WhichOneof('computation'), 'tensorflow') @@ -126,14 +130,16 @@ def test_serialize_tensorflow_with_structured_type_signature(self): comp, extra_type_spec = ( tensorflow_serialization.serialize_py_fn_as_tf_computation( lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [('x', np.int32), ('y', (np.float32, [2]))], batch_type ), - context_stack_impl.context_stack, + federated_language.framework.global_context_stack, ) ) self.assertEqual( - type_serialization.deserialize_type(comp.type).compact_representation(), + federated_language.framework.deserialize_type( + comp.type + ).compact_representation(), '( -> )', ) self.assertEqual(comp.WhichOneof('computation'), 'tensorflow') @@ -142,11 +148,11 @@ def test_serialize_tensorflow_with_structured_type_signature(self): '( -> )', ) self.assertIsInstance( - extra_type_spec.parameter, computation_types.StructWithPythonType + extra_type_spec.parameter, federated_language.StructWithPythonType ) self.assertIs(extra_type_spec.parameter.python_container, batch_type) self.assertIsInstance( - extra_type_spec.result, computation_types.StructWithPythonType + extra_type_spec.result, federated_language.StructWithPythonType ) self.assertIs(extra_type_spec.result.python_container, output_type) @@ -158,12 +164,14 @@ def _legacy_dataset_reducer_example(ds): comp, extra_type_spec = ( tensorflow_serialization.serialize_py_fn_as_tf_computation( _legacy_dataset_reducer_example, - computation_types.SequenceType(np.int64), - context_stack_impl.context_stack, + federated_language.SequenceType(np.int64), + federated_language.framework.global_context_stack, ) ) self.assertEqual( - type_serialization.deserialize_type(comp.type).compact_representation(), + federated_language.framework.deserialize_type( + comp.type + ).compact_representation(), '(int64* -> int64)', ) self.assertEqual(comp.WhichOneof('computation'), 'tensorflow') @@ -193,12 +201,14 @@ def _legacy_dataset_reducer_example(ds): comp, extra_type_spec = ( tensorflow_serialization.serialize_py_fn_as_tf_computation( _legacy_dataset_reducer_example, - computation_types.SequenceType([np.int64]), - context_stack_impl.context_stack, + federated_language.SequenceType([np.int64]), + federated_language.framework.global_context_stack, ) ) self.assertEqual( - type_serialization.deserialize_type(comp.type).compact_representation(), + federated_language.framework.deserialize_type( + comp.type + ).compact_representation(), '(* -> int64)', ) self.assertEqual(comp.WhichOneof('computation'), 'tensorflow') diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types.py index 35e4a8f708..5312041c91 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types.py @@ -13,13 +13,11 @@ # limitations under the License. """Defines functions and classes for building and manipulating TFF types.""" +import federated_language import numpy as np import tensorflow as tf import tree -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types - def _tensorflow_dtype_to_numpy_dtype( dtype: tf.dtypes.DType, @@ -36,7 +34,7 @@ def _tensorflow_dtype_to_numpy_dtype( def _tensor_shape_to_array_shape( tensor_shape: tf.TensorShape, -) -> array_shape.ArrayShape: +) -> federated_language.ArrayShape: """Returns a `tff.types.ArrayShape` for the `tensor_shape`.""" if tensor_shape.rank is not None: shape = tensor_shape.as_list() @@ -45,22 +43,22 @@ def _tensor_shape_to_array_shape( return shape -def _tensor_spec_to_type(tensor_spec: tf.TensorSpec) -> computation_types.Type: +def _tensor_spec_to_type(tensor_spec: tf.TensorSpec) -> federated_language.Type: """Returns a `tff.Type` for the `tensor_spec`.""" dtype = _tensorflow_dtype_to_numpy_dtype(tensor_spec.dtype) shape = _tensor_shape_to_array_shape(tensor_spec.shape) - return computation_types.TensorType(dtype, shape) + return federated_language.TensorType(dtype, shape) def _dataset_spec_to_type( dataset_spec: tf.data.DatasetSpec, -) -> computation_types.Type: +) -> federated_language.Type: """Returns a `tff.Type` for the `dataset_spec`.""" element_type = to_type(dataset_spec.element_spec) - return computation_types.SequenceType(element_type) + return federated_language.SequenceType(element_type) -def to_type(obj: object) -> computation_types.Type: +def to_type(obj: object) -> federated_language.Type: """Returns a `tff.Type` for an `obj` containing TensorFlow type specs. This function extends `tff.types.to_type` to handle TensorFlow type specs and @@ -108,4 +106,4 @@ def _to_type(obj): return None partial_type = tree.traverse(_to_type, obj) - return computation_types.to_type(partial_type) + return federated_language.to_type(partial_type) diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types_test.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types_test.py index bb7d576e16..67ce71010d 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_types_test.py @@ -16,11 +16,11 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types class TensorflowToTypeTest(parameterized.TestCase): @@ -29,14 +29,14 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'dtype', tf.int32, - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ( 'dtype_nested', [tf.int32], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ], list, ), @@ -44,10 +44,10 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'dtype_mixed', [tf.int32, np.float32], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.int32), + federated_language.TensorType(np.float32), ], list, ), @@ -55,39 +55,39 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'tensor_like_shape_fully_defined', (tf.int32, tf.TensorShape([2, 3])), - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ), ( 'tensor_like_shape_partially_defined', (tf.int32, tf.TensorShape([2, None])), - computation_types.TensorType(np.int32, shape=[2, None]), + federated_language.TensorType(np.int32, shape=[2, None]), ), ( 'tensor_like_shape_unknown', (tf.int32, tf.TensorShape(None)), - computation_types.TensorType(np.int32, shape=None), + federated_language.TensorType(np.int32, shape=None), ), ( 'tensor_like_shape_scalar', (tf.int32, tf.TensorShape([])), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ( 'tensor_like_dtype_only', (tf.int32, [2, 3]), - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ), ( 'tensor_like_shape_only', (np.int32, tf.TensorShape([2, 3])), - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ), ( 'tensor_like_nested', [(tf.int32, tf.TensorShape([2, 3]))], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ], list, ), @@ -95,10 +95,10 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'tensor_like_mixed', [(tf.int32, tf.TensorShape([2, 3])), np.float32], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.float32), ], list, ), @@ -106,14 +106,14 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'tensor_spec', tf.TensorSpec(shape=[2, 3], dtype=tf.int32), - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ), ( 'tensor_spec_nested', [tf.TensorSpec(shape=[2, 3], dtype=tf.int32)], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.int32, shape=[2, 3]), ], list, ), @@ -121,10 +121,10 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'tensor_spec_mixed', [tf.TensorSpec(shape=[2, 3], dtype=tf.int32), np.float32], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32, shape=[2, 3]), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.int32, shape=[2, 3]), + federated_language.TensorType(np.float32), ], list, ), @@ -132,8 +132,8 @@ class TensorflowToTypeTest(parameterized.TestCase): ( 'dataset_spec', tf.data.DatasetSpec(tf.TensorSpec(shape=[2, 3], dtype=tf.int32)), - computation_types.SequenceType( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.SequenceType( + federated_language.TensorType(np.int32, shape=[2, 3]) ), ), ( @@ -141,10 +141,10 @@ class TensorflowToTypeTest(parameterized.TestCase): [ tf.data.DatasetSpec(tf.TensorSpec(shape=[2, 3], dtype=tf.int32)), ], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.SequenceType( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.SequenceType( + federated_language.TensorType(np.int32, shape=[2, 3]) ), ], list, @@ -156,12 +156,12 @@ class TensorflowToTypeTest(parameterized.TestCase): tf.data.DatasetSpec(tf.TensorSpec(shape=[2, 3], dtype=tf.int32)), np.float32, ], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.SequenceType( - computation_types.TensorType(np.int32, shape=[2, 3]) + federated_language.SequenceType( + federated_language.TensorType(np.int32, shape=[2, 3]) ), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), ], list, ), @@ -172,7 +172,7 @@ def test_returns_result_with_tensorflow_obj(self, obj, expected_result): self.assertEqual(actual_result, expected_result) @parameterized.named_parameters( - ('type', computation_types.TensorType(np.int32)), + ('type', federated_language.TensorType(np.int32)), ('dtype', np.int32), ('tensor_like', (np.int32, [2, 3])), ('sequence_unnamed', [np.float64, np.int32, np.str_]), @@ -182,7 +182,7 @@ def test_returns_result_with_tensorflow_obj(self, obj, expected_result): def test_delegates_result_with_obj(self, obj): with mock.patch.object( - computation_types, 'to_type', autospec=True, spec_set=True + federated_language, 'to_type', autospec=True, spec_set=True ) as mock_to_type: tensorflow_types.to_type(obj) mock_to_type.assert_called_once_with(obj) diff --git a/tensorflow_federated/python/core/environments/xla_backend/BUILD b/tensorflow_federated/python/core/environments/xla_backend/BUILD index 9427f6ea26..b3942851e9 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/BUILD +++ b/tensorflow_federated/python/core/environments/xla_backend/BUILD @@ -45,12 +45,11 @@ py_library( name = "xla_serialization", srcs = ["xla_serialization.py"], deps = [ - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", "@com_google_protobuf//:protobuf_python", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -60,9 +59,8 @@ py_test( srcs = ["xla_serialization_test.py"], deps = [ ":xla_serialization", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", "@com_google_protobuf//:protobuf_python", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) diff --git a/tensorflow_federated/python/core/environments/xla_backend/xla_serialization.py b/tensorflow_federated/python/core/environments/xla_backend/xla_serialization.py index 4491177f89..800983ef9b 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/xla_serialization.py +++ b/tensorflow_federated/python/core/environments/xla_backend/xla_serialization.py @@ -16,15 +16,14 @@ from collections.abc import Sequence from typing import Optional, TypeVar, Union +import federated_language +from federated_language.proto import computation_pb2 as pb from jax.lib import xla_client import numpy as np from google.protobuf import any_pb2 -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization _HLO_MODULE_PROTO_URI = 'type.googleapis.com/xla.HloModuleProto' @@ -73,7 +72,7 @@ def unpack_xla_computation(any_pb: any_pb2.Any) -> xla_client.XlaComputation: def _make_xla_binding_for_type( - tensor_indexes: Sequence[int], type_spec: Optional[computation_types.Type] + tensor_indexes: Sequence[int], type_spec: Optional[federated_language.Type] ) -> Optional[pb.Xla.Binding]: """Generates an XLA binding for TFF type `type_spec`. @@ -84,7 +83,7 @@ def _make_xla_binding_for_type( tensor_indexes: The list of tensor indexes to use in the binding, in the order matching the order of flattened `type_spec`. type_spec: The type to generate the binding for. Must be either an instance - of `computation_types.Type`, or `None`. + of `federated_language.Type`, or `None`. Returns: The generated binding (either `pb.Xla.Binding` or `None`). @@ -92,13 +91,13 @@ def _make_xla_binding_for_type( if type_spec is None: return None - py_typecheck.check_type(type_spec, computation_types.Type) + py_typecheck.check_type(type_spec, federated_language.Type) py_typecheck.check_type(tensor_indexes, Sequence) def _make_starting_at_index( - type_spec: computation_types.Type, idx: int + type_spec: federated_language.Type, idx: int ) -> tuple[pb.Xla.Binding, int]: - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): return ( pb.Xla.Binding( tensor=pb.Xla.TensorBinding(index=tensor_indexes[idx]) @@ -106,7 +105,7 @@ def _make_starting_at_index( idx + 1, ) - if isinstance(type_spec, computation_types.StructType): + if isinstance(type_spec, federated_language.StructType): elements = [] for _, v in structure.iter_elements(type_spec): binding, idx = _make_starting_at_index(v, idx) @@ -123,10 +122,10 @@ def _make_starting_at_index( _T = TypeVar( '_T', - computation_types.TensorType, - computation_types.StructType, - computation_types.StructWithPythonType, - computation_types.FunctionType, + federated_language.TensorType, + federated_language.StructType, + federated_language.StructWithPythonType, + federated_language.FunctionType, ) @@ -134,7 +133,7 @@ def _remove_struct_element_names_from_tff_type(type_spec: _T) -> _T: """Removes names of struct elements from `type_spec`. Args: - type_spec: An instance of `computation_types.Type` that must be a tensor, a + type_spec: An instance of `federated_language.Type` that must be a tensor, a (possibly) nested structure of tensors, or a function. Returns: @@ -145,26 +144,24 @@ def _remove_struct_element_names_from_tff_type(type_spec: _T) -> _T: """ if type_spec is None: return None - if isinstance(type_spec, computation_types.FunctionType): - return computation_types.FunctionType( + if isinstance(type_spec, federated_language.FunctionType): + return federated_language.FunctionType( _remove_struct_element_names_from_tff_type(type_spec.parameter), # pytype: disable=wrong-arg-types _remove_struct_element_names_from_tff_type(type_spec.result), # pytype: disable=wrong-arg-types ) - if isinstance(type_spec, computation_types.TensorType): + if isinstance(type_spec, federated_language.TensorType): return type_spec - py_typecheck.check_type(type_spec, computation_types.StructType) - return computation_types.StructType( - [ - (None, _remove_struct_element_names_from_tff_type(v)) - for _, v in structure.iter_elements(type_spec) - ] - ) + py_typecheck.check_type(type_spec, federated_language.StructType) + return federated_language.StructType([ + (None, _remove_struct_element_names_from_tff_type(v)) + for _, v in structure.iter_elements(type_spec) + ]) def create_xla_tff_computation( xla_computation: xla_client.XlaComputation, tensor_indexes: Sequence[int], - type_spec: computation_types.FunctionType, + type_spec: federated_language.FunctionType, ) -> pb.Computation: """Creates an XLA TFF computation. @@ -183,7 +180,7 @@ def create_xla_tff_computation( """ py_typecheck.check_type(xla_computation, xla_client.XlaComputation) py_typecheck.check_type(tensor_indexes, Sequence) - py_typecheck.check_type(type_spec, computation_types.FunctionType) + py_typecheck.check_type(type_spec, federated_language.FunctionType) parameter_binding = _make_xla_binding_for_type( tensor_indexes, type_spec.parameter ) @@ -193,7 +190,7 @@ def create_xla_tff_computation( reconstructed_type = xla_computation_and_bindings_to_tff_type( xla_computation, parameter_binding, result_binding ) - py_typecheck.check_type(reconstructed_type, computation_types.FunctionType) + py_typecheck.check_type(reconstructed_type, federated_language.FunctionType) expected_type = _remove_struct_element_names_from_tff_type(type_spec) if not reconstructed_type.is_equivalent_to(expected_type): raise ValueError( @@ -201,7 +198,7 @@ def create_xla_tff_computation( 'TFF type {}.'.format(str(reconstructed_type), str(expected_type)) ) return pb.Computation( - type=type_serialization.serialize_type(type_spec), + type=federated_language.framework.serialize_type(type_spec), xla=pb.Xla( hlo_module=pack_xla_computation(xla_computation), parameter=parameter_binding, @@ -214,7 +211,7 @@ def xla_computation_and_bindings_to_tff_type( xla_computation: xla_client.XlaComputation, parameter_binding: Optional[pb.Xla.Binding], result_binding: pb.Xla.Binding, -) -> computation_types.FunctionType: +) -> federated_language.FunctionType: """Constructs the TFF type from an `xla_client.XlaComputation` and bindings. NOTE: This is a helper function, primarily intended for use in checking the @@ -227,7 +224,7 @@ def xla_computation_and_bindings_to_tff_type( result_binding: An instance of `pb.Xla.Binding` for the result. Returns: - An instance of `computation_types.FunctionType`. + An instance of `federated_language.FunctionType`. """ py_typecheck.check_type(xla_computation, xla_client.XlaComputation) program_shape = xla_computation.program_shape() @@ -249,13 +246,13 @@ def xla_computation_and_bindings_to_tff_type( 'Failed to construct TFF type from result binding:' f'{program_shape.result_shape()=}, {result_binding=}' ) from e - return computation_types.FunctionType(parameter_type, result_type) + return federated_language.FunctionType(parameter_type, result_type) def xla_shapes_and_binding_to_tff_type( xla_shapes: Sequence[xla_client.Shape], binding: Optional[pb.Xla.Binding] ) -> Optional[ - Union[computation_types.TensorType, computation_types.StructType] + Union[federated_language.TensorType, federated_language.StructType] ]: """Constructs the TFF type from a list of `xla_client.Shape` and a binding. @@ -264,7 +261,7 @@ def xla_shapes_and_binding_to_tff_type( binding: An instance of `pb.Xla.Binding` (or `None` if there's none). Returns: - An instance of `computation_types.Type` (or `None`). + An instance of `federated_language.Type` (or `None`). """ py_typecheck.check_type(xla_shapes, Sequence) if binding is not None: @@ -277,7 +274,7 @@ def xla_shapes_and_binding_to_tff_type( def _get_type( binding: Optional[pb.Xla.Binding], ) -> Optional[ - Union[computation_types.TensorType, computation_types.StructType] + Union[federated_language.TensorType, federated_language.StructType] ]: if binding is None: return None @@ -293,11 +290,11 @@ def _get_type( raise ValueError(f'Duplicate bindings referring to {index=}') unused_shape_indexes.remove(index) shape = tensor_shapes[index] - return computation_types.TensorType( + return federated_language.TensorType( shape.numpy_dtype(), shape.dimensions() ) if kind == 'struct': - return computation_types.StructType( + return federated_language.StructType( [(None, _get_type(x)) for x in binding.struct.element] ) if kind is None: diff --git a/tensorflow_federated/python/core/environments/xla_backend/xla_serialization_test.py b/tensorflow_federated/python/core/environments/xla_backend/xla_serialization_test.py index 4c85629d1f..b065b1e207 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/xla_serialization_test.py +++ b/tensorflow_federated/python/core/environments/xla_backend/xla_serialization_test.py @@ -13,15 +13,14 @@ # limitations under the License. from absl.testing import absltest +import federated_language +from federated_language.proto import computation_pb2 as pb import jax import jax.numpy as jnp import numpy as np from google.protobuf import any_pb2 -from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.core.environments.xla_backend import xla_serialization -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization def _make_xla_shape(shapes_and_dtypes_pytree): @@ -68,11 +67,11 @@ def test_pack_unpack_xla_computation_roundtrip(self): def test_create_xla_tff_computation_noarg(self): xla_comp = _make_test_xla_comp_noarg_to_int32() comp_pb = xla_serialization.create_xla_tff_computation( - xla_comp, [], computation_types.FunctionType(None, np.int32) + xla_comp, [], federated_language.FunctionType(None, np.int32) ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) + type_spec = federated_language.framework.deserialize_type(comp_pb.type) self.assertEqual(str(type_spec), '( -> int32)') xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module) self.assertIn( @@ -85,14 +84,14 @@ def test_create_xla_tff_computation_raises_missing_arg_in_xla(self): xla_comp = _make_test_xla_comp_noarg_to_int32() with self.assertRaises(ValueError): xla_serialization.create_xla_tff_computation( - xla_comp, [0], computation_types.FunctionType(np.int32, np.int32) + xla_comp, [0], federated_language.FunctionType(np.int32, np.int32) ) def test_create_xla_tff_computation_raises_missing_arg_in_type_spec(self): xla_comp = _make_test_xla_comp_int32x10_to_int32x10() with self.assertRaises(ValueError): xla_serialization.create_xla_tff_computation( - xla_comp, [], computation_types.FunctionType(None, np.int32) + xla_comp, [], federated_language.FunctionType(None, np.int32) ) def test_create_xla_tff_computation_raises_arg_type_mismatch(self): @@ -101,7 +100,7 @@ def test_create_xla_tff_computation_raises_arg_type_mismatch(self): xla_serialization.create_xla_tff_computation( xla_comp, [0], - computation_types.FunctionType(np.int32, (np.int32, (10,))), + federated_language.FunctionType(np.int32, (np.int32, (10,))), ) def test_create_xla_tff_computation_raises_result_type_mismatch(self): @@ -110,7 +109,7 @@ def test_create_xla_tff_computation_raises_result_type_mismatch(self): xla_serialization.create_xla_tff_computation( xla_comp, [0], - computation_types.FunctionType((np.int32, (10,)), np.int32), + federated_language.FunctionType((np.int32, (10,)), np.int32), ) def test_create_xla_tff_computation_int32x10_to_int32x10(self): @@ -118,11 +117,11 @@ def test_create_xla_tff_computation_int32x10_to_int32x10(self): comp_pb = xla_serialization.create_xla_tff_computation( xla_comp, [0], - computation_types.FunctionType((np.int32, (10,)), (np.int32, (10,))), + federated_language.FunctionType((np.int32, (10,)), (np.int32, (10,))), ) self.assertIsInstance(comp_pb, pb.Computation) self.assertEqual(comp_pb.WhichOneof('computation'), 'xla') - type_spec = type_serialization.deserialize_type(comp_pb.type) + type_spec = federated_language.framework.deserialize_type(comp_pb.type) self.assertEqual(str(type_spec), '(int32[10] -> int32[10])') def test_create_xla_tff_computation_with_reordered_tensor_indexes(self): @@ -138,28 +137,28 @@ def dot(x, y): comp_pb_1 = xla_serialization.create_xla_tff_computation( xla_comp, [0, 1], - computation_types.FunctionType( + federated_language.FunctionType( ((np.int32, (10, 1)), (np.int32, (1, 20))), (np.int32, (10, 20)), ), ) self.assertIsInstance(comp_pb_1, pb.Computation) self.assertEqual(comp_pb_1.WhichOneof('computation'), 'xla') - type_spec_1 = type_serialization.deserialize_type(comp_pb_1.type) + type_spec_1 = federated_language.framework.deserialize_type(comp_pb_1.type) self.assertEqual( str(type_spec_1), '( -> int32[10,20])' ) comp_pb_2 = xla_serialization.create_xla_tff_computation( xla_comp, [1, 0], - computation_types.FunctionType( + federated_language.FunctionType( ((np.int32, (1, 20)), (np.int32, (10, 1))), (np.int32, (10, 20)), ), ) self.assertIsInstance(comp_pb_2, pb.Computation) self.assertEqual(comp_pb_2.WhichOneof('computation'), 'xla') - type_spec_2 = type_serialization.deserialize_type(comp_pb_2.type) + type_spec_2 = federated_language.framework.deserialize_type(comp_pb_2.type) self.assertEqual( str(type_spec_2), '( -> int32[10,20])' ) diff --git a/tensorflow_federated/python/core/framework/BUILD b/tensorflow_federated/python/core/framework/BUILD index 804ec0477d..001d9b64d7 100644 --- a/tensorflow_federated/python/core/framework/BUILD +++ b/tensorflow_federated/python/core/framework/BUILD @@ -17,30 +17,14 @@ py_library( srcs = ["__init__.py"], visibility = ["//tensorflow_federated:__pkg__"], deps = [ - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/compiler:transformation_utils", "//tensorflow_federated/python/core/impl/compiler:transformations", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:computation_serialization", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/context_stack:get_context_stack", - "//tensorflow_federated/python/core/impl/context_stack:set_default_context", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", "//tensorflow_federated/python/core/impl/execution_contexts:mergeable_comp_execution_context", - "//tensorflow_federated/python/core/impl/execution_contexts:sync_execution_context", "//tensorflow_federated/python/core/impl/executor_stacks:executor_factory", "//tensorflow_federated/python/core/impl/executor_stacks:python_executor_stacks", - "//tensorflow_federated/python/core/impl/executors:executor_base", - "//tensorflow_federated/python/core/impl/executors:executor_factory", - "//tensorflow_federated/python/core/impl/executors:executors_errors", "//tensorflow_federated/python/core/impl/executors:remote_executor", "//tensorflow_federated/python/core/impl/executors:remote_executor_grpc_stub", "//tensorflow_federated/python/core/impl/executors:remote_executor_stub", "//tensorflow_federated/python/core/impl/executors:value_serialization", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/framework/__init__.py b/tensorflow_federated/python/core/framework/__init__.py index d1d224ea2e..1243aa6491 100644 --- a/tensorflow_federated/python/core/framework/__init__.py +++ b/tensorflow_federated/python/core/framework/__init__.py @@ -13,60 +13,71 @@ # limitations under the License. """Libraries for extending the TensorFlow Federated core library.""" +import federated_language # pylint: disable=g-importing-member -from tensorflow_federated.python.core.impl.compiler.building_block_factory import unique_name_generator -from tensorflow_federated.python.core.impl.compiler.building_blocks import Block -from tensorflow_federated.python.core.impl.compiler.building_blocks import Call -from tensorflow_federated.python.core.impl.compiler.building_blocks import CompiledComputation -from tensorflow_federated.python.core.impl.compiler.building_blocks import ComputationBuildingBlock -from tensorflow_federated.python.core.impl.compiler.building_blocks import Data -from tensorflow_federated.python.core.impl.compiler.building_blocks import Intrinsic -from tensorflow_federated.python.core.impl.compiler.building_blocks import Lambda -from tensorflow_federated.python.core.impl.compiler.building_blocks import Literal -from tensorflow_federated.python.core.impl.compiler.building_blocks import Placement -from tensorflow_federated.python.core.impl.compiler.building_blocks import Reference -from tensorflow_federated.python.core.impl.compiler.building_blocks import Selection -from tensorflow_federated.python.core.impl.compiler.building_blocks import Struct -from tensorflow_federated.python.core.impl.compiler.building_blocks import UnexpectedBlockError -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_AGGREGATE -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_APPLY -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_BROADCAST -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_EVAL_AT_CLIENTS -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_EVAL_AT_SERVER -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_MAP -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_MAP_ALL_EQUAL -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_SUM -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_VALUE_AT_CLIENTS -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_VALUE_AT_SERVER -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_ZIP_AT_CLIENTS -from tensorflow_federated.python.core.impl.compiler.intrinsic_defs import FEDERATED_ZIP_AT_SERVER -from tensorflow_federated.python.core.impl.compiler.transformation_utils import transform_postorder -from tensorflow_federated.python.core.impl.compiler.transformation_utils import transform_preorder +unique_name_generator = federated_language.framework.unique_name_generator +Block = federated_language.framework.Block +Call = federated_language.framework.Call +CompiledComputation = federated_language.framework.CompiledComputation +ComputationBuildingBlock = federated_language.framework.ComputationBuildingBlock +Data = federated_language.framework.Data +Intrinsic = federated_language.framework.Intrinsic +Lambda = federated_language.framework.Lambda +Literal = federated_language.framework.Literal +Placement = federated_language.framework.Placement +Reference = federated_language.framework.Reference +Selection = federated_language.framework.Selection +Struct = federated_language.framework.Struct +UnexpectedBlockError = federated_language.framework.UnexpectedBlockError +FEDERATED_AGGREGATE = federated_language.framework.FEDERATED_AGGREGATE +FEDERATED_APPLY = federated_language.framework.FEDERATED_APPLY +FEDERATED_BROADCAST = federated_language.framework.FEDERATED_BROADCAST +FEDERATED_EVAL_AT_CLIENTS = ( + federated_language.framework.FEDERATED_EVAL_AT_CLIENTS +) +FEDERATED_EVAL_AT_SERVER = federated_language.framework.FEDERATED_EVAL_AT_SERVER +FEDERATED_MAP = federated_language.framework.FEDERATED_MAP +FEDERATED_MAP_ALL_EQUAL = federated_language.framework.FEDERATED_MAP_ALL_EQUAL +FEDERATED_SUM = federated_language.framework.FEDERATED_SUM +FEDERATED_VALUE_AT_CLIENTS = ( + federated_language.framework.FEDERATED_VALUE_AT_CLIENTS +) +FEDERATED_VALUE_AT_SERVER = ( + federated_language.framework.FEDERATED_VALUE_AT_SERVER +) +FEDERATED_ZIP_AT_CLIENTS = federated_language.framework.FEDERATED_ZIP_AT_CLIENTS +FEDERATED_ZIP_AT_SERVER = federated_language.framework.FEDERATED_ZIP_AT_SERVER +transform_postorder = federated_language.framework.transform_postorder +transform_preorder = federated_language.framework.transform_preorder from tensorflow_federated.python.core.impl.compiler.transformations import to_call_dominant -from tensorflow_federated.python.core.impl.computation.computation_impl import ConcreteComputation -from tensorflow_federated.python.core.impl.computation.computation_serialization import deserialize_computation -from tensorflow_federated.python.core.impl.computation.computation_serialization import serialize_computation -from tensorflow_federated.python.core.impl.computation.function_utils import pack_args_into_struct -from tensorflow_federated.python.core.impl.computation.function_utils import unpack_args_from_struct -from tensorflow_federated.python.core.impl.context_stack.context_base import AsyncContext -from tensorflow_federated.python.core.impl.context_stack.context_base import SyncContext -from tensorflow_federated.python.core.impl.context_stack.context_stack_base import ContextStack -from tensorflow_federated.python.core.impl.context_stack.get_context_stack import get_context_stack -from tensorflow_federated.python.core.impl.context_stack.set_default_context import set_default_context -from tensorflow_federated.python.core.impl.execution_contexts.async_execution_context import AsyncExecutionContext + +ConcreteComputation = federated_language.framework.ConcreteComputation +deserialize_computation = federated_language.framework.deserialize_computation +serialize_computation = federated_language.framework.serialize_computation +pack_args_into_struct = federated_language.framework.pack_args_into_struct +unpack_args_from_struct = federated_language.framework.unpack_args_from_struct +AsyncContext = federated_language.framework.AsyncContext +SyncContext = federated_language.framework.SyncContext +ContextStack = federated_language.framework.ContextStack +get_context_stack = federated_language.framework.get_context_stack +set_default_context = federated_language.framework.set_default_context +AsyncExecutionContext = federated_language.framework.AsyncExecutionContext from tensorflow_federated.python.core.impl.execution_contexts.mergeable_comp_execution_context import MergeableCompExecutionContext from tensorflow_federated.python.core.impl.execution_contexts.mergeable_comp_execution_context import MergeableCompForm -from tensorflow_federated.python.core.impl.execution_contexts.sync_execution_context import SyncExecutionContext + +SyncExecutionContext = federated_language.framework.SyncExecutionContext from tensorflow_federated.python.core.impl.executor_stacks.executor_factory import local_cpp_executor_factory from tensorflow_federated.python.core.impl.executor_stacks.python_executor_stacks import ResourceManagingExecutorFactory -from tensorflow_federated.python.core.impl.executors.executor_base import Executor -from tensorflow_federated.python.core.impl.executors.executor_factory import CardinalitiesType -from tensorflow_federated.python.core.impl.executors.executor_factory import ExecutorFactory -from tensorflow_federated.python.core.impl.executors.executors_errors import RetryableError + +Executor = federated_language.framework.Executor +CardinalitiesType = federated_language.framework.CardinalitiesType +ExecutorFactory = federated_language.framework.ExecutorFactory +RetryableError = federated_language.framework.RetryableError from tensorflow_federated.python.core.impl.executors.remote_executor import RemoteExecutor from tensorflow_federated.python.core.impl.executors.remote_executor_grpc_stub import RemoteExecutorGrpcStub from tensorflow_federated.python.core.impl.executors.remote_executor_stub import RemoteExecutorStub from tensorflow_federated.python.core.impl.executors.value_serialization import deserialize_value from tensorflow_federated.python.core.impl.executors.value_serialization import serialize_value -from tensorflow_federated.python.core.impl.types.placements import PlacementLiteral + +PlacementLiteral = federated_language.framework.PlacementLiteral # pylint: enable=g-importing-member diff --git a/tensorflow_federated/python/core/impl/BUILD b/tensorflow_federated/python/core/impl/BUILD index 182079a322..9382aa0f96 100644 --- a/tensorflow_federated/python/core/impl/BUILD +++ b/tensorflow_federated/python/core/impl/BUILD @@ -19,7 +19,6 @@ package_group( "//tensorflow_federated/python/core/environments:environments_packages", "//tensorflow_federated/python/core/framework:framework_packages", "//tensorflow_federated/python/core/templates:templates_packages", - "//tensorflow_federated/python/core/test:test_packages", "//tensorflow_federated/python/learning:learning_packages", "//tensorflow_federated/python/program:program_packages", "//tensorflow_federated/python/simulation:simulation_packages", diff --git a/tensorflow_federated/python/core/impl/compiler/BUILD b/tensorflow_federated/python/core/impl/compiler/BUILD index 552f440d08..6d572aa21f 100644 --- a/tensorflow_federated/python/core/impl/compiler/BUILD +++ b/tensorflow_federated/python/core/impl/compiler/BUILD @@ -5,10 +5,8 @@ package( default_visibility = [ ":compiler_packages", "//tensorflow_federated/python/core/impl:impl_users", - "//tensorflow_federated/python/core/impl/computation:computation_packages", "//tensorflow_federated/python/core/impl/execution_contexts:execution_contexts_packages", "//tensorflow_federated/python/core/impl/executors:executors_packages", - "//tensorflow_federated/python/core/impl/federated_context:federated_context_packages", ], ) @@ -25,212 +23,24 @@ py_library( visibility = ["//tools/python_package:python_package_tool"], ) -py_library( - name = "array", - srcs = ["array.py"], - deps = [ - "//tensorflow_federated/proto/v0:array_py_pb2", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:dtype_utils", - ], -) - -py_test( - name = "array_test", - srcs = ["array_test.py"], - deps = [ - ":array", - "//tensorflow_federated/proto/v0:array_py_pb2", - "//tensorflow_federated/proto/v0:data_type_py_pb2", - ], -) - -py_library( - name = "building_block_analysis", - srcs = ["building_block_analysis.py"], - deps = [":building_blocks"], -) - -py_library( - name = "building_block_factory", - srcs = ["building_block_factory.py"], - deps = [ - ":building_blocks", - ":intrinsic_defs", - ":transformation_utils", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:type_transformations", - ], -) - -py_test( - name = "building_block_factory_test", - size = "large", - srcs = ["building_block_factory_test.py"], - args = [ - "--golden", - "$(location building_block_factory_test_goldens/constructs_correct_computation_clients.expected)", - "--golden", - "$(location building_block_factory_test_goldens/constructs_correct_computation_server.expected)", - "--golden", - "$(location building_block_factory_test_goldens/replaces_single_element.expected)", - "--golden", - "$(location building_block_factory_test_goldens/skips_unnamed_element.expected)", - "--golden", - "$(location building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected)", - "--golden", - "$(location building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected)", - "--golden", - "$(location building_block_factory_test_goldens/tuple_federated_map_with_two_values_different_typed.expected)", - "--golden", - "$(location building_block_factory_test_goldens/tuple_federated_apply_with_two_values_unnamed.expected)", - "--golden", - "$(location building_block_factory_test_goldens/tuple_federated_apply_with_two_values_named.expected)", - "--golden", - "$(location building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected)", - "--golden", - "$(location building_block_factory_test_goldens/zips_tuple_unnamed.expected)", - "--golden", - "$(location building_block_factory_test_goldens/zips_tuple_named.expected)", - "--golden", - "$(location building_block_factory_test_goldens/zips_reference.expected)", - ], - data = [ - "building_block_factory_test_goldens/constructs_correct_computation_clients.expected", - "building_block_factory_test_goldens/constructs_correct_computation_server.expected", - "building_block_factory_test_goldens/replaces_single_element.expected", - "building_block_factory_test_goldens/skips_unnamed_element.expected", - "building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected", - "building_block_factory_test_goldens/tuple_federated_apply_with_two_values_named.expected", - "building_block_factory_test_goldens/tuple_federated_apply_with_two_values_unnamed.expected", - "building_block_factory_test_goldens/tuple_federated_map_with_two_values_different_typed.expected", - "building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected", - "building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected", - "building_block_factory_test_goldens/zips_reference.expected", - "building_block_factory_test_goldens/zips_tuple_named.expected", - "building_block_factory_test_goldens/zips_tuple_unnamed.expected", - ], - deps = [ - ":building_block_factory", - ":building_blocks", - "//tensorflow_federated/python/common_libs:golden", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_test_utils", - ], -) - py_library( name = "building_block_test_utils", testonly = True, srcs = ["building_block_test_utils.py"], deps = [ - ":array", - ":building_block_factory", - ":building_blocks", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "@com_google_protobuf//:protobuf_python", - ], -) - -py_library( - name = "building_blocks", - srcs = ["building_blocks.py"], - deps = [ - ":array", - ":intrinsic_defs", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:typed_object", - "@com_google_protobuf//:protobuf_python", - ], -) - -py_test( - name = "building_blocks_test", - size = "small", - srcs = ["building_blocks_test.py"], - deps = [ - ":array", - ":building_block_test_utils", - ":building_blocks", - ":computation_factory", - ":intrinsic_defs", - "//tensorflow_federated/proto/v0:array_py_pb2", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/proto/v0:data_type_py_pb2", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_serialization", "@com_google_protobuf//:protobuf_python", + "@federated_language//federated_language", ], ) -py_library( - name = "computation_factory", - srcs = ["computation_factory.py"], - deps = [ - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_factory", - "//tensorflow_federated/python/core/impl/types:type_serialization", - ], -) - -py_test( - name = "computation_factory_test", - srcs = ["computation_factory_test.py"], - deps = [ - ":computation_factory", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_factory", - "//tensorflow_federated/python/core/impl/types:type_serialization", - ], -) - -py_library( - name = "intrinsic_defs", - srcs = ["intrinsic_defs.py"], - deps = [ - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_factory", - ], -) - -py_test( - name = "intrinsic_defs_test", - size = "small", - srcs = ["intrinsic_defs_test.py"], - deps = [":intrinsic_defs"], -) - py_library( name = "transformations", srcs = ["transformations.py"], deps = [ - ":building_block_factory", - ":building_blocks", - ":intrinsic_defs", - ":transformation_utils", - ":tree_analysis", ":tree_transformations", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -238,75 +48,12 @@ py_test( name = "transformations_test", srcs = ["transformations_test.py"], deps = [ - ":building_block_factory", ":building_block_test_utils", - ":building_blocks", - ":intrinsic_defs", - ":transformation_utils", ":transformations", - ":tree_analysis", ":tree_transformations", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_test_utils", - ], -) - -py_library( - name = "transformation_utils", - srcs = ["transformation_utils.py"], - deps = [ - ":building_blocks", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_test( - name = "transformation_utils_test", - size = "small", - srcs = ["transformation_utils_test.py"], - deps = [ - ":building_block_test_utils", - ":building_blocks", - ":computation_factory", - ":transformation_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "tree_analysis", - srcs = ["tree_analysis.py"], - deps = [ - ":building_block_analysis", - ":building_blocks", - ":intrinsic_defs", - ":transformation_utils", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - ], -) - -py_test( - name = "tree_analysis_test", - srcs = ["tree_analysis_test.py"], - deps = [ - ":building_block_factory", - ":building_block_test_utils", - ":building_blocks", - ":intrinsic_defs", - ":tree_analysis", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -314,16 +61,8 @@ py_library( name = "tree_transformations", srcs = ["tree_transformations.py"], deps = [ - ":building_block_analysis", - ":building_block_factory", - ":building_blocks", - ":intrinsic_defs", - ":transformation_utils", - ":tree_analysis", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_transformations", + "@federated_language//federated_language", ], ) @@ -339,18 +78,11 @@ py_test( "tree_transformations_test_goldens/uniquify_names_blocks_nested_inside_of_locals.expected", ], deps = [ - ":building_block_factory", ":building_block_test_utils", - ":building_blocks", - ":intrinsic_defs", - ":transformation_utils", - ":tree_analysis", ":tree_transformations", "//tensorflow_federated/python/common_libs:golden", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -359,10 +91,9 @@ py_library( testonly = True, srcs = ["compiler_test_utils.py"], deps = [ - ":building_blocks", - ":transformation_utils", "//tensorflow_federated/python/common_libs:golden", "//tensorflow_federated/python/common_libs:py_typecheck", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/impl/compiler/array.py b/tensorflow_federated/python/core/impl/compiler/array.py deleted file mode 100644 index 38d36300f9..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/array.py +++ /dev/null @@ -1,423 +0,0 @@ -# Copyright 2024, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for working with arrays.""" - -from typing import Optional, Union - -import ml_dtypes -import numpy as np - -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import dtype_utils - -# Array is the Python representation of the `Array` protobuf, and is the native -# representation of an array. -Array = Union[ - # Python types - bool, - int, - float, - complex, - str, - bytes, - # Numpy types - np.generic, - np.ndarray, -] - - -def from_proto(array_pb: array_pb2.Array) -> Array: - """Returns an `Array` for the `array_pb`.""" - dtype = dtype_utils.from_proto(array_pb.dtype) - shape = array_shape.from_proto(array_pb.shape) - - if dtype is np.bool_: - value = array_pb.bool_list.value - elif dtype is np.int8: - value = array_pb.int8_list.value - elif dtype is np.int16: - value = array_pb.int16_list.value - elif dtype is np.int32: - value = array_pb.int32_list.value - elif dtype is np.int64: - value = array_pb.int64_list.value - elif dtype is np.uint8: - value = array_pb.uint8_list.value - elif dtype is np.uint16: - value = array_pb.uint16_list.value - elif dtype is np.uint32: - value = array_pb.uint32_list.value - elif dtype is np.uint64: - value = array_pb.uint64_list.value - elif dtype is np.float16: - value = array_pb.float16_list.value - # Values of dtype `np.float16` are packed to and unpacked from a protobuf - # field of type `int32` using the following logic in order to maintain - # compatibility with how other external environments (e.g., TensorFlow, JAX) - # represent values of `np.float16`. - value = np.asarray(value, np.uint16).view(np.float16).tolist() - elif dtype is np.float32: - value = array_pb.float32_list.value - elif dtype is np.float64: - value = array_pb.float64_list.value - elif dtype is np.complex64: - if len(array_pb.complex64_list.value) % 2 != 0: - raise ValueError( - 'Expected the number of complex values to be even, one real and one' - ' imaginary part for each complex value.' - ) - value = iter(array_pb.complex64_list.value) - value = [complex(real, imag) for real, imag in zip(value, value)] - elif dtype is np.complex128: - if len(array_pb.complex128_list.value) % 2 != 0: - raise ValueError( - 'Expected the number of complex values to be even, one real and one' - ' imaginary part for each complex value.' - ) - value = iter(array_pb.complex128_list.value) - value = [complex(real, imag) for real, imag in zip(value, value)] - elif dtype is ml_dtypes.bfloat16: - value = array_pb.bfloat16_list.value - # Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a - # protobuf field of type `int32` using the following logic in order to - # maintain compatibility with how other external environments (e.g., - # TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`. - value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist() - elif dtype is np.str_: - value = array_pb.string_list.value - else: - raise NotImplementedError(f'Unexpected `dtype` found: {dtype}.') - - # Strings are stored as bytes in `array_pb2.Array` and trailing null values - # are dropped when using `np.bytes_`, use `np.object_` instead. - if dtype is np.str_: - dtype = np.object_ - - # `Array` is a `Union` of native Python types and numpy types. However, the - # protobuf representation of `Array` contains additional information like - # dtype and shape. This information is lost when returning native Python types - # making it impossible to infer the original dtype later. Therefore, a numpy - # value should almost always be returned from this function. String values are - # an exception to this because it's not possible to represent null-terminated - # scalar strings using numpy and this is ok because string types can only be - # inferred as string types. - if not array_shape.is_shape_scalar(shape): - value = np.array(value, dtype).reshape(shape) - else: - (value,) = value - value = dtype(value) - - return value - - -def to_proto( - value: Array, *, dtype_hint: Optional[type[np.generic]] = None -) -> array_pb2.Array: - """Returns an `array_pb2.Array` for the `value`.""" - - if dtype_hint is not None: - if not dtype_utils.is_valid_dtype(dtype_hint): - raise ValueError( - f'Expected `dtype_hint` to be a valid dtype, found {dtype_hint}.' - ) - if not is_compatible_dtype(value, dtype_hint): - raise ValueError(f'Expected {value} to be compatible with {dtype_hint}.') - dtype = dtype_hint - else: - if isinstance(value, (np.ndarray, np.generic)): - dtype = value.dtype.type - # If the value has a dtype of `np.bytes_` or `np.object_`, the serialized - # dtype should still be a `np.str_`. - if np.issubdtype(dtype, np.bytes_) or np.issubdtype(dtype, np.object_): - dtype = np.str_ - else: - dtype = dtype_utils.infer_dtype(value) - - # Normalize to a numpy value; strings are stored as bytes in `array_pb2.Array` - # and trailing null values are dropped when using `np.bytes_`, so use - # `np.object_` instead. - if dtype is np.str_: - - def _contains_type(value, classinfo): - if isinstance(value, (np.ndarray, np.generic)): - if value.size == 0: - return False - item = value.item(0) - else: - item = value - return isinstance(item, classinfo) - - if _contains_type(value, str): - value = np.asarray(value, np.bytes_) - else: - value = np.asarray(value, np.object_) - else: - value = np.asarray(value, dtype) - - dtype_pb = dtype_utils.to_proto(dtype) - shape_pb = array_shape.to_proto(value.shape) - value = value.flatten().tolist() - - if dtype is np.bool_: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - bool_list=array_pb2.Array.BoolList(value=value), - ) - elif dtype is np.int8: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - int8_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.int16: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - int16_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.int32: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - int32_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.int64: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - int64_list=array_pb2.Array.Int64List(value=value), - ) - elif dtype is np.uint8: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - uint8_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.uint16: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - uint16_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.uint32: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - uint32_list=array_pb2.Array.Uint32List(value=value), - ) - elif dtype is np.uint64: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - uint64_list=array_pb2.Array.Uint64List(value=value), - ) - elif dtype is np.float16: - # Values of dtype `np.float16` are packed to and unpacked from a protobuf - # field of type `int32` using the following logic in order to maintain - # compatibility with how other external environments (e.g., TensorFlow, JAX) - # represent values of `np.float16`. - value = np.asarray(value, np.float16).view(np.uint16).tolist() - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - float16_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.float32: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - float32_list=array_pb2.Array.FloatList(value=value), - ) - elif dtype is np.float64: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - float64_list=array_pb2.Array.DoubleList(value=value), - ) - elif dtype is np.complex64: - packed_value = [] - for x in value: - if not isinstance(x, complex): - raise ValueError(f'Expected a complex type, found {type(x)}.') - packed_value.extend([x.real, x.imag]) - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - complex64_list=array_pb2.Array.FloatList(value=packed_value), - ) - elif dtype is np.complex128: - packed_value = [] - for x in value: - if not isinstance(x, complex): - raise ValueError(f'Expected a complex type, found {type(x)}.') - packed_value.extend([x.real, x.imag]) - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - complex128_list=array_pb2.Array.DoubleList(value=packed_value), - ) - elif dtype is ml_dtypes.bfloat16: - # Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a - # protobuf field of type `int32` using the following logic in order to - # maintain compatibility with how other external environments (e.g., - # TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`. - value = np.asarray(value, ml_dtypes.bfloat16).view(np.uint16).tolist() - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - bfloat16_list=array_pb2.Array.IntList(value=value), - ) - elif dtype is np.str_: - return array_pb2.Array( - dtype=dtype_pb, - shape=shape_pb, - string_list=array_pb2.Array.BytesList(value=value), - ) - else: - raise NotImplementedError(f'Unexpected `dtype` found: {dtype}.') - - -def from_proto_content(array_pb: array_pb2.Array) -> Array: - """Returns an `Array` for the `array_pb`.""" - dtype = dtype_utils.from_proto(array_pb.dtype) - shape = array_shape.from_proto(array_pb.shape) - - if dtype is not np.str_: - value = np.frombuffer(array_pb.content, dtype) - else: - raise NotImplementedError(f'Unexpected `dtype` found: {dtype}.') - - # `Array` is a `Union` of native Python types and numpy types. However, the - # protobuf representation of `Array` contains additional information like - # dtype and shape. This information is lost when returning native Python types - # making it impossible to infer the original dtype later. Therefore, a numpy - # value should almost always be returned from this function. String values are - # an exception to this because it's not possible to represent null-terminated - # scalar strings using numpy and this is ok because string types can only be - # inferred as string types. - if not array_shape.is_shape_scalar(shape): - value = value.reshape(shape) - else: - value = value.item() - value = dtype(value) - - return value - - -def to_proto_content( - value: Array, *, dtype_hint: Optional[type[np.generic]] = None -) -> array_pb2.Array: - """Returns an `Array` for the `value`.""" - - if dtype_hint is not None: - if not dtype_utils.is_valid_dtype(dtype_hint): - raise ValueError( - f'Expected `dtype_hint` to be a valid dtype, found {dtype_hint}.' - ) - if not is_compatible_dtype(value, dtype_hint): - raise ValueError(f'Expected {value} to be compatible with {dtype_hint}.') - dtype = dtype_hint - else: - if isinstance(value, (np.ndarray, np.generic)): - dtype = value.dtype.type - # If the value has a dtype of `np.bytes_` or `np.object_`, the serialized - # dtype should still be a `np.str_`. - if np.issubdtype(dtype, np.bytes_) or np.issubdtype(dtype, np.object_): - dtype = np.str_ - else: - dtype = dtype_utils.infer_dtype(value) - - # Normalize to a numpy value. - if dtype is not np.str_: - value = np.asarray(value, dtype) - else: - raise NotImplementedError(f'Unexpected `dtype` found: {dtype}.') - - dtype_pb = dtype_utils.to_proto(dtype) - shape_pb = array_shape.to_proto(value.shape) - content = value.tobytes() - - return array_pb2.Array(dtype=dtype_pb, shape=shape_pb, content=content) - - -def is_compatible_dtype(value: Array, dtype: type[np.generic]) -> bool: - """Returns `True` if `value` is compatible with `dtype`, otherwise `False`. - - This functions checks that the `value` has the same scalar kind (e.g. integer, - floating) and has a compatible size (e.g. 32-bits, 16-bits) as `dtype` . - - See https://numpy.org/doc/stable/reference/arrays.scalars.html for more - information. - - Args: - value: The value to check. - dtype: The dtype to check against. - """ - if isinstance(value, (np.ndarray, np.generic)): - value_dtype = value.dtype.type - else: - value_dtype = type(value) - if value_dtype is dtype: - return True - - # Check dtype kind. - if np.issubdtype(value_dtype, np.bool_): - # Skip checking dtype size, `np.bool_` does not have a size. - return dtype is np.bool_ - elif np.issubdtype(value_dtype, np.integer): - if not np.issubdtype(dtype, np.integer): - return False - elif np.issubdtype(value_dtype, np.floating): - if not np.issubdtype(dtype, np.floating): - return False - elif np.issubdtype(value_dtype, np.complexfloating): - if not np.issubdtype(dtype, np.complexfloating): - return False - elif np.issubdtype(value_dtype, np.character) or np.issubdtype( - value_dtype, np.object_ - ): - # Skip checking dtype size, `np.str_`, `np.bytes_`, and `np.object_` - # (null-terminated bytes) have a variable length. - return dtype is np.str_ - else: - return False - - # Check dtype size. - if isinstance(value, (np.ndarray, np.generic)): - # `np.can_cast` does not does not apply value-based logic to `np.ndarray` or - # numpy scalars (since version 2.0). Testing the `dtype` of the value rather - # the the value aligns how `np.ndarray` and `np.generic` types are handled - # across different versions of numpy. See - # https://numpy.org/doc/stable/reference/generated/numpy.can_cast.html for - # more information. - return np.can_cast(value.dtype, dtype) - elif isinstance(value, (int, float, complex)): - return dtype_utils.can_cast(value, dtype) - else: - return False - - -def is_compatible_shape(value: Array, shape: array_shape.ArrayShape) -> bool: - """Returns `True` if `value` is compatible with `shape`, otherwise `False`. - - Args: - value: The value to check. - shape: The `tff.types.ArrayShape` to check against. - """ - if isinstance(value, np.ndarray): - return array_shape.is_compatible_with(value.shape, shape) - else: - return array_shape.is_shape_scalar(shape) diff --git a/tensorflow_federated/python/core/impl/compiler/array_test.py b/tensorflow_federated/python/core/impl/compiler/array_test.py deleted file mode 100644 index 0f1dc708bd..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/array_test.py +++ /dev/null @@ -1,1575 +0,0 @@ -# Copyright 2023, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -from absl.testing import absltest -from absl.testing import parameterized -import ml_dtypes -import numpy as np - -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.proto.v0 import data_type_pb2 -from tensorflow_federated.python.core.impl.compiler import array - - -class FromProtoTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'bool', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BOOL, - shape=array_pb2.ArrayShape(dim=[]), - bool_list=array_pb2.Array.BoolList(value=[True]), - ), - np.bool_(True), - ), - ( - 'int8', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT8, - shape=array_pb2.ArrayShape(dim=[]), - int8_list=array_pb2.Array.IntList(value=[1]), - ), - np.int8(1), - ), - ( - 'int16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT16, - shape=array_pb2.ArrayShape(dim=[]), - int16_list=array_pb2.Array.IntList(value=[1]), - ), - np.int16(1), - ), - ( - 'int32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ), - np.int32(1), - ), - ( - 'int64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - int64_list=array_pb2.Array.Int64List(value=[1]), - ), - np.int64(1), - ), - ( - 'uint8', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT8, - shape=array_pb2.ArrayShape(dim=[]), - uint8_list=array_pb2.Array.IntList(value=[1]), - ), - np.uint8(1), - ), - ( - 'uint16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT16, - shape=array_pb2.ArrayShape(dim=[]), - uint16_list=array_pb2.Array.IntList(value=[1]), - ), - np.uint16(1), - ), - ( - 'uint32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT32, - shape=array_pb2.ArrayShape(dim=[]), - uint32_list=array_pb2.Array.Uint32List(value=[1]), - ), - np.uint32(1), - ), - ( - 'uint64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT64, - shape=array_pb2.ArrayShape(dim=[]), - uint64_list=array_pb2.Array.Uint64List(value=[1]), - ), - np.uint64(1), - ), - ( - 'float16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_HALF, - shape=array_pb2.ArrayShape(dim=[]), - float16_list=array_pb2.Array.IntList( - value=[np.asarray(1.0, np.float16).view(np.uint16).item()] - ), - ), - np.float16(1.0), - ), - ( - 'float32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[1.0]), - ), - np.float32(1.0), - ), - ( - 'float64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_DOUBLE, - shape=array_pb2.ArrayShape(dim=[]), - float64_list=array_pb2.Array.DoubleList(value=[1.0]), - ), - np.float64(1.0), - ), - ( - 'complex64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX64, - shape=array_pb2.ArrayShape(dim=[]), - complex64_list=array_pb2.Array.FloatList(value=[1.0, 1.0]), - ), - np.complex64(1.0 + 1.0j), - ), - ( - 'complex128', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - complex128_list=array_pb2.Array.DoubleList(value=[1.0, 1.0]), - ), - np.complex128(1.0 + 1.0j), - ), - ( - 'bfloat16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - bfloat16_list=array_pb2.Array.IntList( - value=[ - np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() - ] - ), - ), - ml_dtypes.bfloat16(1.0), - ), - ( - 'str', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - b'abc', - ), - ( - 'str_null_terminated', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc\x00\x00']), - ), - b'abc\x00\x00', - ), - ( - 'array_int32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[2, 3]), - int32_list=array_pb2.Array.IntList(value=[1, 2, 3, 4, 5, 6]), - ), - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - ), - ( - 'array_int32_empty', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[0]), - int32_list=array_pb2.Array.IntList(value=[]), - ), - np.array([], np.int32), - ), - ( - 'array_str', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - np.array([b'abc', b'def'], np.object_), - ), - ( - 'array_str_null_terminated', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList( - value=[b'abc\x00\x00', b'def\x00\x00'] - ), - ), - np.array([b'abc\x00\x00', b'def\x00\x00'], np.object_), - ), - ( - 'nan', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[np.nan]), - ), - np.float32(np.nan), - ), - ( - 'inf', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[np.inf]), - ), - np.float32(np.inf), - ), - ) - def test_returns_value(self, proto, expected_value): - actual_value = array.from_proto(proto) - - if isinstance(actual_value, (np.ndarray, np.generic)): - np.testing.assert_array_equal(actual_value, expected_value, strict=True) - else: - self.assertIsInstance(actual_value, type(expected_value)) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'complex64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX64, - shape=array_pb2.ArrayShape(dim=[]), - complex64_list=array_pb2.Array.FloatList(value=[1.0]), - ), - ), - ( - 'complex128', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - complex128_list=array_pb2.Array.DoubleList(value=[1.0]), - ), - ), - ) - def test_raises_value_error_with_wrong_value(self, proto): - with self.assertRaises(ValueError): - array.from_proto(proto) - - -class ToProtoTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'bool', - True, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BOOL, - shape=array_pb2.ArrayShape(dim=[]), - bool_list=array_pb2.Array.BoolList(value=[True]), - ), - ), - ( - 'int32', - np.iinfo(np.int32).max, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList( - value=[np.iinfo(np.int32).max] - ), - ), - ), - ( - 'int64', - np.iinfo(np.int64).max, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - int64_list=array_pb2.Array.Int64List( - value=[np.iinfo(np.int64).max] - ), - ), - ), - ( - 'float', - 1.0, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[1.0]), - ), - ), - ( - 'complex', - complex(1.0, 1.0), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - complex128_list=array_pb2.Array.DoubleList(value=[1.0, 1.0]), - ), - ), - ( - 'str', - 'abc', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'bytes', - b'abc', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'bytes_null_terminated', - b'abc\x00\x00', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc\x00\x00']), - ), - ), - ( - 'generic_int32', - np.int32(1), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'generic_float16', - np.float16(1.0), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_HALF, - shape=array_pb2.ArrayShape(dim=[]), - float16_list=array_pb2.Array.IntList( - value=[np.asarray(1.0, np.float16).view(np.uint16).item()] - ), - ), - ), - ( - 'generic_bfloat16', - ml_dtypes.bfloat16(1.0), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - bfloat16_list=array_pb2.Array.IntList( - value=[ - np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() - ] - ), - ), - ), - ( - 'generic_str', - np.str_('abc'), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'generic_bytes', - np.bytes_(b'abc'), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'array_int32', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[2, 3]), - int32_list=array_pb2.Array.IntList(value=[1, 2, 3, 4, 5, 6]), - ), - ), - ( - 'array_int32_empty', - np.array([], np.int32), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[0]), - int32_list=array_pb2.Array.IntList(value=[]), - ), - ), - ( - 'array_float16', - np.array([1.0], np.float16), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_HALF, - shape=array_pb2.ArrayShape(dim=[1]), - float16_list=array_pb2.Array.IntList( - value=[np.asarray(1.0, np.float16).view(np.uint16).item()] - ), - ), - ), - ( - 'array_bfloat16', - np.array([1.0], ml_dtypes.bfloat16), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[1]), - bfloat16_list=array_pb2.Array.IntList( - value=[ - np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() - ] - ), - ), - ), - ( - 'array_str', - np.array(['abc', 'def'], np.str_), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_bytes', - np.array([b'abc', b'def'], np.bytes_), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_object_str', - np.array(['abc', 'def'], np.object_), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_object_bytes', - np.array([b'abc', b'def'], np.object_), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_object_bytes_null_terminated', - np.array([b'abc\x00\x00', b'def\x00\x00'], np.object_), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList( - value=[b'abc\x00\x00', b'def\x00\x00'] - ), - ), - ), - ( - 'nan', - np.nan, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[np.nan]), - ), - ), - ( - 'inf', - np.inf, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[np.inf]), - ), - ), - ) - def test_returns_value_with_no_dtype_hint(self, value, expected_value): - actual_value = array.to_proto(value) - - # Externally protobuf does not compare NaN values as equal. - if isinstance(value, float) and math.isnan(value): - self.assertEqual(actual_value.dtype, expected_value.dtype) - self.assertEqual(actual_value.shape, expected_value.shape) - self.assertLen(actual_value.float32_list.value, 1) - self.assertTrue(math.isnan(actual_value.float32_list.value[0])) - else: - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'bool', - True, - np.bool_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BOOL, - shape=array_pb2.ArrayShape(dim=[]), - bool_list=array_pb2.Array.BoolList(value=[True]), - ), - ), - ( - 'int8', - 1, - np.int8, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT8, - shape=array_pb2.ArrayShape(dim=[]), - int8_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'int16', - 1, - np.int16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT16, - shape=array_pb2.ArrayShape(dim=[]), - int16_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'int32', - 1, - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'int64', - 1, - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - int64_list=array_pb2.Array.Int64List(value=[1]), - ), - ), - ( - 'uint8', - 1, - np.uint8, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT8, - shape=array_pb2.ArrayShape(dim=[]), - uint8_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'uint16', - 1, - np.uint16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT16, - shape=array_pb2.ArrayShape(dim=[]), - uint16_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'uint32', - 1, - np.uint32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT32, - shape=array_pb2.ArrayShape(dim=[]), - uint32_list=array_pb2.Array.Uint32List(value=[1]), - ), - ), - ( - 'uint64', - 1, - np.uint64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT64, - shape=array_pb2.ArrayShape(dim=[]), - uint64_list=array_pb2.Array.Uint64List(value=[1]), - ), - ), - ( - 'float16', - 1.0, - np.float16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_HALF, - shape=array_pb2.ArrayShape(dim=[]), - float16_list=array_pb2.Array.IntList( - value=[np.asarray(1.0, np.float16).view(np.uint16).item()] - ), - ), - ), - ( - 'float32', - 1.0, - np.float32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[1.0]), - ), - ), - ( - 'float64', - 1.0, - np.float64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_DOUBLE, - shape=array_pb2.ArrayShape(dim=[]), - float64_list=array_pb2.Array.DoubleList(value=[1.0]), - ), - ), - ( - 'complex64', - (1.0 + 1.0j), - np.complex64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX64, - shape=array_pb2.ArrayShape(dim=[]), - complex64_list=array_pb2.Array.FloatList(value=[1.0, 1.0]), - ), - ), - ( - 'complex128', - (1.0 + 1.0j), - np.complex128, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - complex128_list=array_pb2.Array.DoubleList(value=[1.0, 1.0]), - ), - ), - ( - 'str', - 'abc', - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'bytes', - b'abc', - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'bytes_null_terminated', - b'abc\x00\x00', - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc\x00\x00']), - ), - ), - ( - 'generic_int32', - np.int32(1), - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ), - ), - ( - 'generic_bfloat16', - # Note: we must not use Python `float` here because ml_dtypes.bfloat16 - # is declared as kind `V` (void) not `f` (float) to prevent numpy from - # trying to equate float16 and bfloat16 (which are not compatible). - ml_dtypes.bfloat16(1.0), - ml_dtypes.bfloat16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - bfloat16_list=array_pb2.Array.IntList( - value=[ - np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() - ] - ), - ), - ), - ( - 'generic_str', - np.str_('abc'), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'generic_bytes', - np.bytes_(b'abc'), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - string_list=array_pb2.Array.BytesList(value=[b'abc']), - ), - ), - ( - 'array_int32', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[2, 3]), - int32_list=array_pb2.Array.IntList(value=[1, 2, 3, 4, 5, 6]), - ), - ), - ( - 'array_int32_empty', - np.array([], np.int32), - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[0]), - int32_list=array_pb2.Array.IntList(value=[]), - ), - ), - ( - 'array_str', - np.array(['abc', 'def'], np.str_), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_bytes', - np.array([b'abc', b'def'], np.bytes_), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_object_str', - np.array(['abc', 'def'], np.object_), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_object_bytes', - np.array([b'abc', b'def'], np.object_), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList(value=[b'abc', b'def']), - ), - ), - ( - 'array_object_bytes_null_terminated', - np.array([b'abc\x00\x00', b'def\x00\x00'], np.object_), - np.str_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[2]), - string_list=array_pb2.Array.BytesList( - value=[b'abc\x00\x00', b'def\x00\x00'] - ), - ), - ), - ( - 'nan', - np.nan, - np.float32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[np.nan]), - ), - ), - ( - 'inf', - np.inf, - np.float32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - float32_list=array_pb2.Array.FloatList(value=[np.inf]), - ), - ), - ( - 'scalar_different_dtype', - 1, - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - int64_list=array_pb2.Array.Int64List(value=[1]), - ), - ), - ( - 'generic_different_dtype', - np.int32(1), - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - int64_list=array_pb2.Array.Int64List(value=[1]), - ), - ), - ( - 'array_different_dtype', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[2, 3]), - int64_list=array_pb2.Array.Int64List(value=[1, 2, 3, 4, 5, 6]), - ), - ), - ) - def test_returns_value_with_dtype_hint(self, value, dtype, expected_value): - actual_value = array.to_proto(value, dtype_hint=dtype) - - # Externally protobuf does not compare NaN values as equal. - if isinstance(value, float) and math.isnan(value): - self.assertEqual(actual_value.dtype, expected_value.dtype) - self.assertEqual(actual_value.shape, expected_value.shape) - self.assertLen(actual_value.float32_list.value, 1) - self.assertTrue(math.isnan(actual_value.float32_list.value[0])) - else: - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ('bytes', b'abc', np.bytes_), - ) - def test_raises_value_error_with_invalid_dtype_hint(self, value, dtype): - with self.assertRaises(ValueError): - array.to_proto(value, dtype_hint=dtype) - - @parameterized.named_parameters( - ('scalar', np.iinfo(np.int64).max, np.int32), - ('generic', np.int64(np.iinfo(np.int64).max), np.int32), - ('array', np.array([np.iinfo(np.int64).max] * 3, np.int64), np.int32), - ) - def test_raises_value_error_with_incompatible_dtype_hint(self, value, dtype): - with self.assertRaises(ValueError): - array.to_proto(value, dtype_hint=dtype) - - @parameterized.named_parameters( - ('None', None), - ('object', object()), - ) - def test_raises_not_implemented_error(self, value): - with self.assertRaises(NotImplementedError): - array.to_proto(value) - - -class FromProtoContentTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'bool', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BOOL, - shape=array_pb2.ArrayShape(dim=[]), - content=np.bool_(True).tobytes(), - ), - np.bool_(True), - ), - ( - 'int8', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT8, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int8(1).tobytes(), - ), - np.int8(1), - ), - ( - 'int16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT16, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int16(1).tobytes(), - ), - np.int16(1), - ), - ( - 'int32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int32(1).tobytes(), - ), - np.int32(1), - ), - ( - 'int64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int64(1).tobytes(), - ), - np.int64(1), - ), - ( - 'uint8', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT8, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint8(1).tobytes(), - ), - np.uint8(1), - ), - ( - 'uint16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT16, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint16(1).tobytes(), - ), - np.uint16(1), - ), - ( - 'uint32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint32(1).tobytes(), - ), - np.uint32(1), - ), - ( - 'uint64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint64(1).tobytes(), - ), - np.uint64(1), - ), - ( - 'float16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_HALF, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float16(1.0).tobytes(), - ), - np.float16(1.0), - ), - ( - 'float32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(1.0).tobytes(), - ), - np.float32(1.0), - ), - ( - 'float64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_DOUBLE, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float64(1.0).tobytes(), - ), - np.float64(1.0), - ), - ( - 'complex64', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.complex64(1.0 + 1.0j).tobytes(), - ), - np.complex64(1.0 + 1.0j), - ), - ( - 'complex128', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - content=np.complex128(1.0 + 1.0j).tobytes(), - ), - np.complex128(1.0 + 1.0j), - ), - ( - 'bfloat16', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - content=ml_dtypes.bfloat16(1.0).tobytes(), - ), - ml_dtypes.bfloat16(1.0), - ), - ( - 'array_int32', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[2, 3]), - content=np.array( - [[1, 2, 3], [4, 5, 6]], dtype=np.int32 - ).tobytes(), - ), - np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), - ), - ( - 'array_int32_empty', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[0]), - int32_list=array_pb2.Array.IntList(value=[]), - ), - np.array([], np.int32), - ), - ( - 'nan', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(np.nan).tobytes(), - ), - np.float32(np.nan), - ), - ( - 'inf', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(np.inf).tobytes(), - ), - np.float32(np.inf), - ), - ) - def test_returns_value(self, proto, expected_value): - actual_value = array.from_proto_content(proto) - - if isinstance(actual_value, (np.ndarray, np.generic)): - np.testing.assert_array_equal(actual_value, expected_value, strict=True) - else: - self.assertIsInstance(actual_value, type(expected_value)) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'str', - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_STRING, - shape=array_pb2.ArrayShape(dim=[]), - ), - ), - ) - def test_raises_value_error_with_invalid_dtype(self, proto): - with self.assertRaises(ValueError): - array.from_proto(proto) - - -class ToProtoContentTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'bool', - True, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BOOL, - shape=array_pb2.ArrayShape(dim=[]), - content=np.bool_(True).tobytes(), - ), - ), - ( - 'int32', - np.iinfo(np.int32).max, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int32(np.iinfo(np.int32).max).tobytes(), - ), - ), - ( - 'int64', - np.iinfo(np.int64).max, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int64(np.iinfo(np.int64).max).tobytes(), - ), - ), - ( - 'float', - 1.0, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(1.0).tobytes(), - ), - ), - ( - 'complex', - (1.0 + 1.0j), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - content=np.complex128(1.0 + 1.0j).tobytes(), - ), - ), - ( - 'generic_int32', - np.int32(1), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int32(1).tobytes(), - ), - ), - ( - 'generic_bfloat16', - ml_dtypes.bfloat16(1.0), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - content=ml_dtypes.bfloat16(1.0).tobytes(), - ), - ), - ( - 'array_int32', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[2, 3]), - content=np.array([[1, 2, 3], [4, 5, 6]], np.int32).tobytes(), - ), - ), - ( - 'array_int32_empty', - np.array([], np.int32), - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[0]), - content=np.array([], np.int32).tobytes(), - ), - ), - ( - 'nan', - np.nan, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(np.nan).tobytes(), - ), - ), - ( - 'inf', - np.inf, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(np.inf).tobytes(), - ), - ), - ) - def test_returns_value_with_no_dtype_hint(self, value, expected_value): - actual_value = array.to_proto_content(value) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'bool', - True, - np.bool_, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BOOL, - shape=array_pb2.ArrayShape(dim=[]), - content=np.bool_(True).tobytes(), - ), - ), - ( - 'int8', - 1, - np.int8, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT8, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int8(1).tobytes(), - ), - ), - ( - 'int16', - 1, - np.int16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT16, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int16(1).tobytes(), - ), - ), - ( - 'int32', - 1, - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int32(1).tobytes(), - ), - ), - ( - 'int64', - 1, - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int64(1).tobytes(), - ), - ), - ( - 'uint8', - 1, - np.uint8, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT8, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint8(1).tobytes(), - ), - ), - ( - 'uint16', - 1, - np.uint16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT16, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint16(1).tobytes(), - ), - ), - ( - 'uint32', - 1, - np.uint32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint32(1).tobytes(), - ), - ), - ( - 'uint64', - 1, - np.uint64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_UINT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.uint64(1).tobytes(), - ), - ), - ( - 'float16', - 1.0, - np.float16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_HALF, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float16(1.0).tobytes(), - ), - ), - ( - 'float32', - 1.0, - np.float32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(1.0).tobytes(), - ), - ), - ( - 'float64', - 1.0, - np.float64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_DOUBLE, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float64(1.0).tobytes(), - ), - ), - ( - 'complex64', - (1.0 + 1.0j), - np.complex64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.complex64(1.0 + 1.0j).tobytes(), - ), - ), - ( - 'complex128', - (1.0 + 1.0j), - np.complex128, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_COMPLEX128, - shape=array_pb2.ArrayShape(dim=[]), - content=np.complex128(1.0 + 1.0j).tobytes(), - ), - ), - ( - 'generic_int32', - np.int32(1), - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int32(1).tobytes(), - ), - ), - ( - 'generic_bfloat16', - ml_dtypes.bfloat16(1.0), - ml_dtypes.bfloat16, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_BFLOAT16, - shape=array_pb2.ArrayShape(dim=[]), - content=ml_dtypes.bfloat16(1.0).tobytes(), - ), - ), - ( - 'array_int32', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[2, 3]), - content=np.array([[1, 2, 3], [4, 5, 6]], np.int32).tobytes(), - ), - ), - ( - 'array_int32_empty', - np.array([], np.int32), - np.int32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[0]), - content=np.array([], np.int32).tobytes(), - ), - ), - ( - 'nan', - np.nan, - np.float32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(np.nan).tobytes(), - ), - ), - ( - 'inf', - np.inf, - np.float32, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_FLOAT, - shape=array_pb2.ArrayShape(dim=[]), - content=np.float32(np.inf).tobytes(), - ), - ), - ( - 'scalar_different_dtype', - 1, - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int64(1).tobytes(), - ), - ), - ( - 'generic_different_dtype', - np.int32(1), - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[]), - content=np.int64(1).tobytes(), - ), - ), - ( - 'array_different_dtype', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - np.int64, - array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT64, - shape=array_pb2.ArrayShape(dim=[2, 3]), - content=np.array([[1, 2, 3], [4, 5, 6]], np.int64).tobytes(), - ), - ), - ) - def test_returns_value_with_dtype_hint(self, value, dtype, expected_value): - actual_value = array.to_proto_content(value, dtype_hint=dtype) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ('bytes', b'abc', np.bytes_), - ) - def test_raises_value_error_with_invalid_dtype_hint(self, value, dtype): - with self.assertRaises(ValueError): - array.to_proto_content(value, dtype_hint=dtype) - - @parameterized.named_parameters( - ('scalar', np.iinfo(np.int64).max, np.int32), - ('generic', np.int64(np.iinfo(np.int64).max), np.int32), - ('array', np.array([np.iinfo(np.int64).max] * 3, np.int64), np.int32), - ) - def test_raises_value_error_with_incompatible_dtype_hint(self, value, dtype): - with self.assertRaises(ValueError): - array.to_proto(value, dtype_hint=dtype) - - @parameterized.named_parameters( - ('None', None), - ('object', object()), - ) - def test_raises_not_implemented_error(self, value): - with self.assertRaises(NotImplementedError): - array.to_proto(value) - - -class IsCompatibleDtypeTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('bool', True, np.bool_), - ('int8', 1, np.int8), - ('int16', 1, np.int16), - ('int32', 1, np.int32), - ('int64', 1, np.int64), - ('uint8', 1, np.uint8), - ('uint16', 1, np.uint16), - ('uint32', 1, np.uint32), - ('uint64', 1, np.uint64), - ('float16', 1.0, np.float16), - ('float32', 1.0, np.float32), - ('float64', 1.0, np.float64), - ('complex64', (1.0 + 1.0j), np.complex64), - ('complex128', (1.0 + 1.0j), np.complex128), - ('bfloat16', ml_dtypes.bfloat16(1.0), ml_dtypes.bfloat16), - ('str', 'abc', np.str_), - ('bytes', b'abc', np.str_), - ('generic_int32', np.int32(1), np.int32), - ('array_int32', np.array([[1, 2, 3], [4, 5, 6]], np.int32), np.int32), - ('array_str', np.array(['abc', 'def'], np.str_), np.str_), - ('array_bytes', np.array([b'abc', b'def'], np.bytes_), np.str_), - ( - 'array_bytes_null_terminated', - np.array([b'abc\x00\x00', b'def\x00\x00'], np.object_), - np.str_, - ), - ('nan', np.nan, np.float32), - ('inf', np.inf, np.float32), - ) - def test_returns_true(self, value, dtype): - result = array.is_compatible_dtype(value, dtype) - self.assertTrue(result) - - @parameterized.named_parameters( - ('scalar_incompatible_dtype_kind', 1, np.float32), - ('scalar_incompatible_dtype_kind_bfloat16', 1.0, ml_dtypes.bfloat16), - ('scalar_incompatible_dtype_size_int', np.iinfo(np.int64).max, np.int32), - ( - 'scalar_incompatible_dtype_size_float', - float(np.finfo(np.float64).max), - np.float32, - ), - ( - 'scalar_incompatible_dtype_size_complex_real', - complex(np.finfo(np.float64).max, 1), - np.complex64, - ), - ( - 'scalar_incompatible_dtype_size_complex_imaginary', - complex(1, np.finfo(np.float64).max), - np.complex64, - ), - ('generic_incompatible_dtype_kind', np.int32(1), np.float32), - ( - 'generic_incompatible_dtype_size', - np.int64(np.iinfo(np.int64).max), - np.int32, - ), - ( - 'array_incompatible_dtype_kind', - np.array([1, 2, 3], np.int32), - np.float32, - ), - ( - 'array_incompatible_dtype_size', - np.array([np.iinfo(np.int64).max] * 3, np.int64), - np.int32, - ), - ) - def test_returns_false(self, value, dtype): - result = array.is_compatible_dtype(value, dtype) - self.assertFalse(result) - - -class IsCompatibleShapeTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('scalar', 1, []), - ('generic', np.int32(1), []), - ('array', np.array([[1, 2, 3], [4, 5, 6]], np.int32), [2, 3]), - ) - def test_returns_true(self, value, shape): - result = array.is_compatible_shape(value, shape) - self.assertTrue(result) - - @parameterized.named_parameters( - ('scalar', 1, [3]), - ('generic', np.int32(1), [3]), - ('array', np.array([[1, 2, 3], [4, 5, 6]], np.int32), [3]), - ) - def test_returns_false(self, value, shape): - result = array.is_compatible_shape(value, shape) - self.assertFalse(result) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_analysis.py b/tensorflow_federated/python/core/impl/compiler/building_block_analysis.py deleted file mode 100644 index d052deef90..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_analysis.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of static analysis functions for building blocks.""" - -from tensorflow_federated.python.core.impl.compiler import building_blocks - - -def is_called_intrinsic(comp, uri=None): - """Tests if `comp` is a called intrinsic with the given `uri`. - - Args: - comp: The computation building block to test. - uri: An optional URI or list of URIs; the same form as what is accepted by - isinstance. - - Returns: - `True` if `comp` is a called intrinsic with the given `uri`, otherwise - `False`. - """ - if isinstance(uri, str): - uri = [uri] - return ( - isinstance(comp, building_blocks.Call) - and isinstance(comp.function, building_blocks.Intrinsic) - and (uri is None or comp.function.uri in uri) - ) - - -def is_identity_function(comp): - """Returns `True` if `comp` is an identity function, otherwise `False`.""" - return ( - isinstance(comp, building_blocks.Lambda) - and isinstance(comp.result, building_blocks.Reference) - and comp.parameter_name == comp.result.name - ) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory.py b/tensorflow_federated/python/core/impl/compiler/building_block_factory.py deleted file mode 100644 index cca8a88c9b..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory.py +++ /dev/null @@ -1,1426 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of construction functions for building block structures.""" - -from collections.abc import Iterator, Sequence -import functools -import random -import string -from typing import Optional, Union - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import type_transformations - -Index = Union[str, int] -Path = Union[Index, tuple[Index, ...]] - - -def select_output_from_lambda( - comp: building_blocks.Lambda, paths: Union[Path, list[Path]] -) -> building_blocks.Lambda: - """Constructs a new function with result of selecting `paths` from `comp`. - - Args: - comp: Lambda computation with result type `tff.StructType` from which we - wish to select the sub-results at `paths`. - paths: Either a `Path` or list of `Path`s specifying the indices we wish to - select from the result of `comp`. Each path must be a `tuple` of `str` or - `int` indices from which to select an output. If `paths` is a list, the - returned computation will have a `tff.StructType` result holding each of - the specified selections. - - Returns: - A version of `comp` with result value the selection from the result of - `comp` specified by `paths`. - """ - if not isinstance(comp.type_signature.result, computation_types.StructType): - raise ValueError( - f'Expected a `tff.StructType`, found {comp.type_signature.result}.' - ) - - def _select_path(result, path: Path): - if not isinstance(path, tuple): - path = (path,) - for index in path: - if isinstance(result, building_blocks.Struct): - result = result[index] - elif isinstance(index, str): - result = building_blocks.Selection(result, name=index) - elif isinstance(index, int): - result = building_blocks.Selection(result, index=index) - else: - raise TypeError( - 'Invalid selection type: expected `str` or `int`, ' - f'found value `{index}` of type `{type(index)}`.' - ) - return result - - if isinstance(paths, list): - # Avoid duplicating `comp.result` by binding it to a local. - result_name = next(unique_name_generator(comp)) - result_ref = building_blocks.Reference( - result_name, comp.result.type_signature - ) - elements = [_select_path(result_ref, path) for path in paths] - result = building_blocks.Block( - [(result_name, comp.result)], building_blocks.Struct(elements) - ) - else: - result = _select_path(comp.result, paths) - return building_blocks.Lambda( - comp.parameter_name, comp.parameter_type, result - ) - - -def unique_name_generator( - comp: building_blocks.ComputationBuildingBlock, prefix: str = '_var' -) -> Iterator[str]: - """Yields a new unique name that does not exist in `comp`. - - Args: - comp: The computation building block to use as a reference. - prefix: The prefix to use when generating unique names. If `prefix` is - `None` or if `comp` contains any name with this prefix, then a unique - prefix will be generated from random lowercase ascii characters. - """ - if comp is not None: - names = transformation_utils.get_unique_names(comp) - else: - names = set() - while prefix is None or any(n.startswith(prefix) for n in names): - characters = string.ascii_lowercase - prefix = '_{}'.format(''.join(random.choice(characters) for _ in range(3))) - index = 1 - while True: - yield '{}{}'.format(prefix, index) - index += 1 - - -@functools.lru_cache() -def create_identity( - type_signature: computation_types.Type, -) -> building_blocks.Lambda: - return building_blocks.Lambda( - 'id_arg', - type_signature, - building_blocks.Reference('id_arg', type_signature), - ) - - -class SelectionSpec: - """Data class representing map from input tuple to selection of result. - - Attributes: - tuple_index: The index of the source of the selection sequence in the - desired result of the generated TensorFlow. If this `SelectionSpec` - appears at index i of a list of `SelectionSpec`s, index j is the source - for the result of the generated function at index i. - selection_sequence: A list or tuple representing the selections to make from - `tuple_index`, so that the list `[0]` for example would represent the - output is the 0th element of `tuple_index`, while `[0, 0]` would represent - that the output is the 0th element of the 0th element of `tuple_index`. - """ - - def __init__(self, tuple_index: int, selection_sequence: Sequence[int]): - self._tuple_index = tuple_index - self._selection_sequence = selection_sequence - - @property - def tuple_index(self): - return self._tuple_index - - @property - def selection_sequence(self): - return self._selection_sequence - - def __str__(self): - return 'SelectionSequence(tuple_index={},selection_sequence={}'.format( - self._tuple_index, self._selection_sequence - ) - - def __repr__(self): - return str(self) - - -def create_federated_getitem_call( - arg: building_blocks.ComputationBuildingBlock, idx: Union[int, slice] -) -> building_blocks.Call: - """Creates computation building block passing getitem to federated value. - - Args: - arg: Instance of `building_blocks.ComputationBuildingBlock` of - `computation_types.FederatedType` with member of type - `computation_types.StructType` from which we wish to pick out item `idx`. - idx: Index, instance of `int` or `slice` used to address the - `computation_types.StructType` underlying `arg`. - - Returns: - Returns a `building_blocks.Call` with type signature - `computation_types.FederatedType` of same placement as `arg`, the result - of applying or mapping the appropriate `__getitem__` function, as defined - by `idx`. - """ - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(idx, (int, slice)) - py_typecheck.check_type(arg.type_signature, computation_types.FederatedType) - py_typecheck.check_type( - arg.type_signature.member, # pytype: disable=attribute-error - computation_types.StructType, - ) - getitem_comp = create_federated_getitem_comp(arg, idx) - return create_federated_map_or_apply(getitem_comp, arg) - - -def create_federated_getattr_call( - arg: building_blocks.ComputationBuildingBlock, name: str -) -> building_blocks.Call: - """Creates computation building block passing getattr to federated value. - - Args: - arg: Instance of `building_blocks.ComputationBuildingBlock` of - `computation_types.FederatedType` with member of type - `computation_types.StructType` from which we wish to pick out item `name`. - name: String name to address the `computation_types.StructType` underlying - `arg`. - - Returns: - Returns a `building_blocks.Call` with type signature - `computation_types.FederatedType` of same placement as `arg`, - the result of applying or mapping the appropriate `__getattr__` function, - as defined by `name`. - """ - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(name, str) - py_typecheck.check_type(arg.type_signature, computation_types.FederatedType) - py_typecheck.check_type( - arg.type_signature.member, # pytype: disable=attribute-error - computation_types.StructType, - ) - getattr_comp = create_federated_getattr_comp(arg, name) - return create_federated_map_or_apply(getattr_comp, arg) - - -def create_federated_getattr_comp( - comp: building_blocks.ComputationBuildingBlock, name: str -) -> building_blocks.Lambda: - """Function to construct computation for `federated_apply` of `__getattr__`. - - Creates a `building_blocks.ComputationBuildingBlock` - which selects `name` from its argument, of type `comp.type_signature.member`, - an instance of `computation_types.StructType`. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` with type - signature `computation_types.FederatedType` whose `member` attribute is of - type `computation_types.StructType`. - name: String name of attribute to grab. - - Returns: - Instance of `building_blocks.Lambda` which grabs attribute - according to `name` of its argument. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(comp.type_signature, computation_types.FederatedType) - py_typecheck.check_type( - comp.type_signature.member, # pytype: disable=attribute-error - computation_types.StructType, - ) - py_typecheck.check_type(name, str) - element_names = [ - x for x, _ in comp.type_signature.member.items() # pytype: disable=attribute-error - ] - if name not in element_names: - raise ValueError( - 'The federated value has no element of name `{}`. Value: {}'.format( - name, comp.formatted_representation() - ) - ) - apply_input = building_blocks.Reference('x', comp.type_signature.member) # pytype: disable=attribute-error - selected = building_blocks.Selection(apply_input, name=name) - apply_lambda = building_blocks.Lambda( - 'x', apply_input.type_signature, selected - ) - return apply_lambda - - -def create_federated_getitem_comp( - comp: building_blocks.ComputationBuildingBlock, key: Union[int, slice] -) -> building_blocks.Lambda: - """Function to construct computation for `federated_apply` of `__getitem__`. - - Creates a `building_blocks.ComputationBuildingBlock` - which selects `key` from its argument, of type `comp.type_signature.member`, - of type `computation_types.StructType`. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` with type - signature `computation_types.FederatedType` whose `member` attribute is of - type `computation_types.StructType`. - key: Instance of `int` or `slice`, key used to grab elements from the member - of `comp`. implementation of slicing for `Value` objects with - `type_signature` `computation_types.StructType`. - - Returns: - Instance of `building_blocks.Lambda` which grabs slice - according to `key` of its argument. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(comp.type_signature, computation_types.FederatedType) - py_typecheck.check_type( - comp.type_signature.member, # pytype: disable=attribute-error - computation_types.StructType, - ) - py_typecheck.check_type(key, (int, slice)) - apply_input = building_blocks.Reference('x', comp.type_signature.member) # pytype: disable=attribute-error - if isinstance(key, int): - selected = building_blocks.Selection(apply_input, index=key) - else: - elems = list(comp.type_signature.member.items()) # pytype: disable=attribute-error - index_range = range(*key.indices(len(elems))) - elem_list = [] - for k in index_range: - elem_list.append( - (elems[k][0], building_blocks.Selection(apply_input, index=k)) - ) - selected = building_blocks.Struct(elem_list) - apply_lambda = building_blocks.Lambda( - 'x', apply_input.type_signature, selected - ) - return apply_lambda - - -def _unname_fn_parameter(fn, unnamed_parameter_type): - """Coerces `fn` to a comp whose parameter type is `unnamed_parameter_type`.""" - if any([n for n, _ in fn.type_signature.parameter.items()]): # pytype: disable=attribute-error - return building_blocks.Lambda( - 'a', - unnamed_parameter_type, - building_blocks.Call( - fn, - building_blocks.Reference('a', unnamed_parameter_type), - ), - ) - else: - return fn - - -def create_federated_aggregate( - value: building_blocks.ComputationBuildingBlock, - zero: building_blocks.ComputationBuildingBlock, - accumulate: building_blocks.ComputationBuildingBlock, - merge: building_blocks.ComputationBuildingBlock, - report: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated aggregate. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp, Comp, Comp, Comp] - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - zero: A `building_blocks.ComputationBuildingBlock` to use as the initial - value. - accumulate: A `building_blocks.ComputationBuildingBlock` to use as the - accumulate function. - merge: A `building_blocks.ComputationBuildingBlock` to use as the merge - function. - report: A `building_blocks.ComputationBuildingBlock` to use as the report - function. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(zero, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(accumulate, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(merge, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(report, building_blocks.ComputationBuildingBlock) - # Its okay if the first argument of accumulate is assignable from the zero, - # without being the exact type. This occurs when accumulate has a type like - # ( -> int32[?]) but zero is int32[0]. - zero_arg_type = accumulate.type_signature.parameter[0] # pytype: disable=attribute-error - zero_arg_type.check_assignable_from(zero.type_signature) - result_type = computation_types.FederatedType( - report.type_signature.result, # pytype: disable=attribute-error - placements.SERVER, - ) - - accumulate_parameter_type = computation_types.StructType([ - zero_arg_type, - value.type_signature.member, # pytype: disable=attribute-error - ]) - accumulate = _unname_fn_parameter(accumulate, accumulate_parameter_type) - merge_parameter_type = computation_types.StructType( - [zero_arg_type, zero_arg_type] - ) - merge = _unname_fn_parameter(merge, merge_parameter_type) - - intrinsic_type = computation_types.FunctionType( - ( - type_conversions.type_to_non_all_equal(value.type_signature), - zero_arg_type, - accumulate.type_signature, - merge.type_signature, - report.type_signature, - ), - result_type, - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_AGGREGATE.uri, intrinsic_type - ) - values = building_blocks.Struct((value, zero, accumulate, merge, report)) - return building_blocks.Call(intrinsic, values) - - -def create_federated_apply( - fn: building_blocks.ComputationBuildingBlock, - arg: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated apply. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp] - - Args: - fn: A `building_blocks.ComputationBuildingBlock` to use as the function. - arg: A `building_blocks.ComputationBuildingBlock` to use as the argument. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( - fn.type_signature.result, # pytype: disable=attribute-error - placements.SERVER, - ) - intrinsic_type = computation_types.FunctionType( - (fn.type_signature, arg.type_signature), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type - ) - values = building_blocks.Struct((fn, arg)) - return building_blocks.Call(intrinsic, values) - - -def create_federated_broadcast( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated broadcast. - - Call - / \ - Intrinsic Comp - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( - value.type_signature.member, # pytype: disable=attribute-error - placements.CLIENTS, - all_equal=True, - ) - intrinsic_type = computation_types.FunctionType( - value.type_signature, result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_BROADCAST.uri, intrinsic_type - ) - return building_blocks.Call(intrinsic, value) - - -def create_federated_eval( - fn: building_blocks.ComputationBuildingBlock, - placement: placements.PlacementLiteral, -) -> building_blocks.Call: - r"""Creates a called federated eval. - - Call - / \ - Intrinsic Comp - - Args: - fn: A `building_blocks.ComputationBuildingBlock` to use as the function. - placement: A `placements.PlacementLiteral` to use as the placement. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(fn.type_signature, computation_types.FunctionType) - if placement is placements.CLIENTS: - uri = intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS.uri - all_equal = False - elif placement is placements.SERVER: - uri = intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri - all_equal = True - else: - raise TypeError('Unsupported placement {}.'.format(placement)) - result_type = computation_types.FederatedType( - fn.type_signature.result, # pytype: disable=attribute-error - placement, - all_equal=all_equal, - ) - intrinsic_type = computation_types.FunctionType( - fn.type_signature, result_type - ) - intrinsic = building_blocks.Intrinsic(uri, intrinsic_type) - return building_blocks.Call(intrinsic, fn) - - -def create_federated_map( - fn: building_blocks.ComputationBuildingBlock, - arg: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated map. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp] - - Args: - fn: A `building_blocks.ComputationBuildingBlock` to use as the function. - arg: A `building_blocks.ComputationBuildingBlock` to use as the argument. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - parameter_type = computation_types.FederatedType( - arg.type_signature.member, # pytype: disable=attribute-error - placements.CLIENTS, - ) - result_type = computation_types.FederatedType( - fn.type_signature.result, # pytype: disable=attribute-error - placements.CLIENTS, - ) - intrinsic_type = computation_types.FunctionType( - (fn.type_signature, parameter_type), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MAP.uri, intrinsic_type - ) - values = building_blocks.Struct((fn, arg)) - return building_blocks.Call(intrinsic, values) - - -def create_federated_map_all_equal( - fn: building_blocks.ComputationBuildingBlock, - arg: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated map of equal values. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp] - - Note: The `fn` is required to be deterministic and therefore should contain no - `building_blocks.CompiledComputations`. - - Args: - fn: A `building_blocks.ComputationBuildingBlock` to use as the function. - arg: A `building_blocks.ComputationBuildingBlock` to use as the argument. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - parameter_type = computation_types.FederatedType( - arg.type_signature.member, # pytype: disable=attribute-error - placements.CLIENTS, - all_equal=True, - ) - result_type = computation_types.FederatedType( - fn.type_signature.result, # pytype: disable=attribute-error - placements.CLIENTS, - all_equal=True, - ) - intrinsic_type = computation_types.FunctionType( - (fn.type_signature, parameter_type), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, intrinsic_type - ) - values = building_blocks.Struct((fn, arg)) - return building_blocks.Call(intrinsic, values) - - -def create_federated_map_or_apply( - fn: building_blocks.ComputationBuildingBlock, - arg: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated map or apply depending on `arg`s placement. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp] - - Args: - fn: A `building_blocks.ComputationBuildingBlock` to use as the function. - arg: A `building_blocks.ComputationBuildingBlock` to use as the argument. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - if arg.type_signature.placement is placements.CLIENTS: # pytype: disable=attribute-error - if arg.type_signature.all_equal: # pytype: disable=attribute-error - return create_federated_map_all_equal(fn, arg) - else: - return create_federated_map(fn, arg) - elif arg.type_signature.placement is placements.SERVER: # pytype: disable=attribute-error - return create_federated_apply(fn, arg) - else: - raise TypeError( - 'Unsupported placement {}.'.format(arg.type_signature.placement) # pytype: disable=attribute-error - ) - - -def create_federated_mean( - value: building_blocks.ComputationBuildingBlock, - weight: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated mean. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp] - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - weight: A `building_blocks.ComputationBuildingBlock` to use as the weight or - `None`. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - if weight is not None: - py_typecheck.check_type(weight, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( - value.type_signature.member, # pytype: disable=attribute-error - placements.SERVER, - ) - if weight is not None: - intrinsic_type = computation_types.FunctionType( - ( - type_conversions.type_to_non_all_equal(value.type_signature), - type_conversions.type_to_non_all_equal(weight.type_signature), - ), - result_type, - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, intrinsic_type - ) - values = building_blocks.Struct((value, weight)) - return building_blocks.Call(intrinsic, values) - else: - intrinsic_type = computation_types.FunctionType( - type_conversions.type_to_non_all_equal(value.type_signature), - result_type, - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MEAN.uri, intrinsic_type - ) - return building_blocks.Call(intrinsic, value) - - -def create_federated_min( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated min. - - Call - / \ - Intrinsic Comp - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - - Returns: - A `building_blocks.Call`. - - Raises: - ValueError: If any of the types do not match. - """ - if not isinstance(value.type_signature, computation_types.FederatedType): - raise ValueError('Expected a federated value.') - result_type = computation_types.FederatedType( - value.type_signature.member, - placements.SERVER, - ) - intrinsic_type = computation_types.FunctionType( - type_conversions.type_to_non_all_equal(value.type_signature), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MIN.uri, intrinsic_type - ) - return building_blocks.Call(intrinsic, value) - - -def create_federated_max( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated max. - - Call - / \ - Intrinsic Comp - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - - Returns: - A `building_blocks.Call`. - - Raises: - ValueError: If any of the types do not match. - """ - if not isinstance(value.type_signature, computation_types.FederatedType): - raise ValueError('Expected a federated value.') - result_type = computation_types.FederatedType( - value.type_signature.member, - placements.SERVER, - ) - intrinsic_type = computation_types.FunctionType( - type_conversions.type_to_non_all_equal(value.type_signature), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MAX.uri, intrinsic_type - ) - return building_blocks.Call(intrinsic, value) - - -def create_federated_secure_sum( - value: building_blocks.ComputationBuildingBlock, - max_input: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called secure sum. - - Call - / \ - Intrinsic [Comp, Comp] - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - max_input: A `building_blocks.ComputationBuildingBlock` to use as the - `max_input` value. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(max_input, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( - value.type_signature.member, # pytype: disable=attribute-error - placements.SERVER, - ) - intrinsic_type = computation_types.FunctionType( - [ - type_conversions.type_to_non_all_equal(value.type_signature), - max_input.type_signature, - ], - result_type, - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_SECURE_SUM.uri, intrinsic_type - ) - values = building_blocks.Struct([value, max_input]) - return building_blocks.Call(intrinsic, values) - - -def create_federated_secure_sum_bitwidth( - value: building_blocks.ComputationBuildingBlock, - bitwidth: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called secure sum using bitwidth. - - Call - / \ - Intrinsic [Comp, Comp] - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - bitwidth: A `building_blocks.ComputationBuildingBlock` to use as the - bitwidth value. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(bitwidth, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( - value.type_signature.member, # pytype: disable=attribute-error - placements.SERVER, - ) - intrinsic_type = computation_types.FunctionType( - [ - type_conversions.type_to_non_all_equal(value.type_signature), - bitwidth.type_signature, - ], - result_type, - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri, intrinsic_type - ) - values = building_blocks.Struct([value, bitwidth]) - return building_blocks.Call(intrinsic, values) - - -def create_federated_select( - client_keys, - max_key, - server_val, - select_fn, - secure: bool, -) -> building_blocks.Call: - """Creates a called `federated_select` or `federated_secure_select`.""" - py_typecheck.check_type(client_keys, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(max_key, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(server_val, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(select_fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(secure, bool) - single_key_type = max_key.type_signature.member - select_fn_unnamed_param_type = computation_types.StructType([ - (None, server_val.type_signature.member), - (None, single_key_type), - ]) - select_fn = _unname_fn_parameter(select_fn, select_fn_unnamed_param_type) - result_type = computation_types.FederatedType( - computation_types.SequenceType(select_fn.type_signature.result), - placements.CLIENTS, - ) - intrinsic_type = computation_types.FunctionType( - [ - type_conversions.type_to_non_all_equal(client_keys.type_signature), - max_key.type_signature, - server_val.type_signature, - select_fn.type_signature, - ], - result_type, - ) - if secure: - intrinsic_def = intrinsic_defs.FEDERATED_SECURE_SELECT - else: - intrinsic_def = intrinsic_defs.FEDERATED_SELECT - intrinsic = building_blocks.Intrinsic(intrinsic_def.uri, intrinsic_type) - values = building_blocks.Struct([client_keys, max_key, server_val, select_fn]) - return building_blocks.Call(intrinsic, values) - - -def create_federated_sum( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated sum. - - Call - / \ - Intrinsic Comp - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - result_type = computation_types.FederatedType( - value.type_signature.member, # pytype: disable=attribute-error - placements.SERVER, - ) - intrinsic_type = computation_types.FunctionType( - type_conversions.type_to_non_all_equal(value.type_signature), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_SUM.uri, intrinsic_type - ) - return building_blocks.Call(intrinsic, value) - - -def create_federated_unzip( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Block: - r"""Creates a tuple of called federated maps or applies. - - Block - / \ - [value=Comp] Tuple - | - [Call, Call, ...] - / \ / \ - Intrinsic Tuple Intrinsic Tuple - | | - [Lambda(arg), Ref(value)] [Lambda(arg), Ref(value)] - \ \ - Sel(0) Sel(1) - \ \ - Ref(arg) Ref(arg) - - This function returns a tuple of federated values given a `value` with a - federated tuple type signature. - - Args: - value: A `building_blocks.ComputationBuildingBlock` with a `type_signature` - of type `computation_types.StructType` containing at least one element. - - Returns: - A `building_blocks.Block`. - - Raises: - TypeError: If any of the types do not match. - ValueError: If `value` does not contain any elements. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - named_type_signatures = list(value.type_signature.member.items()) # pytype: disable=attribute-error - length = len(named_type_signatures) - if length == 0: - raise ValueError('federated_zip is only supported on non-empty tuples.') - value_ref = building_blocks.Reference('value', value.type_signature) - elements = [] - fn_ref = building_blocks.Reference('arg', named_type_signatures) - for index, (name, _) in enumerate(named_type_signatures): - sel = building_blocks.Selection(fn_ref, index=index) - fn = building_blocks.Lambda(fn_ref.name, fn_ref.type_signature, sel) - intrinsic = create_federated_map_or_apply(fn, value_ref) - elements.append((name, intrinsic)) - result = building_blocks.Struct( - elements, - value.type_signature.member.python_container, # pytype: disable=attribute-error - ) - symbols = ((value_ref.name, value),) - return building_blocks.Block(symbols, result) - - -def create_federated_value( - value: building_blocks.ComputationBuildingBlock, - placement: placements.PlacementLiteral, -) -> building_blocks.Call: - r"""Creates a called federated value. - - Call - / \ - Intrinsic Comp - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - placement: A `placements.PlacementLiteral` to use as the placement. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - if placement is placements.CLIENTS: - uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri - elif placement is placements.SERVER: - uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri - else: - raise TypeError('Unsupported placement {}.'.format(placement)) - result_type = computation_types.FederatedType( - value.type_signature, placement, all_equal=True - ) - intrinsic_type = computation_types.FunctionType( - value.type_signature, result_type - ) - intrinsic = building_blocks.Intrinsic(uri, intrinsic_type) - return building_blocks.Call(intrinsic, value) - - -def _check_placements(placement_values: set[placements.PlacementLiteral]): - """Checks if the placements of the values being zipped are compatible.""" - if not placement_values: - raise TypeError( - 'federated_zip is only supported on nested structures ' - 'containing at least one FederatedType, but none were ' - 'found.' - ) - elif len(placement_values) > 1: - placement_list = ', '.join(placement.name for placement in placement_values) - raise TypeError( - 'federated_zip requires all nested FederatedTypes to ' - 'have the same placement, but values placed at ' - f'{placement_list} were found.' - ) - - -def create_federated_zip( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called federated zip. - - This function accepts a value whose type signature is a (potentially) nested - tuple structure of federated values all with the same placement, and uses - one of the federated_zip intrinsics (at client or at server) to promote the - placement to the highest level. E.g., A value of type ', C@S>>' - would be mapped to a value of type ', C>>@S'. - - Args: - value: A `building_blocks.ComputationBuildingBlock` with a `type_signature` - of type `computation_types.StructType` that may contain other nested - `computation_types.StructTypes` bottoming out in at least one element of - type `computation_Types.FederatedType`. These federated types must be at - the same placement. - - Returns: - A `building_blocks.Call` whose type signature is now a federated - `computation_types.StructType`, placed at the same placement as the - leaves of `value`. - - Raises: - TypeError: If any of the types do not match. - ValueError: If `value` does not contain any elements. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(value.type_signature, computation_types.StructType) - - all_placements = set() - - def _record_placements(type_signature: computation_types.Type): - """Records the placements in `type_signature` to `all_placements`.""" - if isinstance(type_signature, computation_types.FederatedType): - all_placements.add(type_signature.placement) - elif isinstance(type_signature, computation_types.StructType): - for i, _ in enumerate(type_signature): - _record_placements(type_signature[i]) - else: - raise TypeError( - 'Expected type signatures consisting of structures of StructType ' - 'bottoming out in FederatedType, found: \n{}'.format(type_signature) - ) - - _record_placements(value.type_signature) - _check_placements(all_placements) - placement = all_placements.pop() - if placement is placements.CLIENTS: - uri = intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri - elif placement is placements.SERVER: - uri = intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri - else: - raise TypeError('Unsupported placement {}.'.format(placement)) - - def normalize_all_equals(element_type): - if ( - isinstance(element_type, computation_types.FederatedType) - and element_type.placement is placements.CLIENTS - and element_type.all_equal - ): - return ( - computation_types.FederatedType( - element_type.member, placements.CLIENTS - ), - True, - ) - return element_type, False - - normalized_input_type, _ = type_transformations.transform_type_postorder( - value.type_signature, normalize_all_equals - ) - - unplaced_output_type = type_transformations.strip_placement( - value.type_signature - ) - output_type = computation_types.FederatedType(unplaced_output_type, placement) - intrinsic_type = computation_types.FunctionType( - normalized_input_type, output_type - ) - intrinsic = building_blocks.Intrinsic(uri, intrinsic_type) - return building_blocks.Call(intrinsic, value) - - -def create_sequence_map( - fn: building_blocks.ComputationBuildingBlock, - arg: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called sequence map. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp] - - Args: - fn: A `building_blocks.ComputationBuildingBlock` to use as the function. - arg: A `building_blocks.ComputationBuildingBlock` to use as the argument. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) - result_type = computation_types.SequenceType(fn.type_signature.result) # pytype: disable=attribute-error - intrinsic_type = computation_types.FunctionType( - (fn.type_signature, arg.type_signature), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.SEQUENCE_MAP.uri, intrinsic_type - ) - values = building_blocks.Struct((fn, arg)) - return building_blocks.Call(intrinsic, values) - - -def create_sequence_reduce( - value: building_blocks.ComputationBuildingBlock, - zero: building_blocks.ComputationBuildingBlock, - op: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called sequence reduce. - - Call - / \ - Intrinsic Tuple - | - [Comp, Comp, Comp] - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - zero: A `building_blocks.ComputationBuildingBlock` to use as the initial - value. - op: A `building_blocks.ComputationBuildingBlock` to use as the op function. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(zero, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(op, building_blocks.ComputationBuildingBlock) - op_parameter_type = computation_types.StructType([ - zero.type_signature, - value.type_signature.element, # pytype: disable=attribute-error - ]) - op = _unname_fn_parameter(op, op_parameter_type) - intrinsic_type = computation_types.FunctionType( - ( - value.type_signature, - zero.type_signature, - op.type_signature, - ), - op.type_signature.result, # pytype: disable=attribute-error - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_type - ) - values = building_blocks.Struct((value, zero, op)) - return building_blocks.Call(intrinsic, values) - - -def create_sequence_sum( - value: building_blocks.ComputationBuildingBlock, -) -> building_blocks.Call: - r"""Creates a called sequence sum. - - Call - / \ - Intrinsic Comp - - Args: - value: A `building_blocks.ComputationBuildingBlock` to use as the value. - - Returns: - A `building_blocks.Call`. - - Raises: - TypeError: If any of the types do not match. - """ - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - intrinsic_type = computation_types.FunctionType( - value.type_signature, - value.type_signature.element, # pytype: disable=attribute-error - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.SEQUENCE_SUM.uri, intrinsic_type - ) - return building_blocks.Call(intrinsic, value) - - -def _create_naming_function(tuple_type_to_name, names_to_add, container_type): - """Private function to construct lambda naming a given tuple type. - - Args: - tuple_type_to_name: Instance of `computation_types.StructType`, the type of - the argument which we wish to name. - names_to_add: Python `list` or `tuple`, the names we wish to give to - `tuple_type_to_name`. - container_type: Optional Python container type to associate with the - resulting tuple. - - Returns: - An instance of `building_blocks.Lambda` representing a function - which will take an argument of type `tuple_type_to_name` and return a tuple - with the same elements, but with names in `names_to_add` attached. - - Raises: - ValueError: If `tuple_type_to_name` and `names_to_add` have different - lengths. - """ - py_typecheck.check_type(tuple_type_to_name, computation_types.StructType) - if len(names_to_add) != len(tuple_type_to_name): # pytype: disable=wrong-arg-types - raise ValueError( - 'Number of elements in `names_to_add` must match number of element in ' - 'the named tuple type `tuple_type_to_name`; here, `names_to_add` has ' - '{} elements and `tuple_type_to_name` has {}.'.format( - len(names_to_add), # pytype: disable=wrong-arg-types - len(tuple_type_to_name), # pytype: disable=wrong-arg-types - ) - ) - naming_lambda_arg = building_blocks.Reference('x', tuple_type_to_name) - - def _create_struct_element(i): - return ( - names_to_add[i], - building_blocks.Selection(naming_lambda_arg, index=i), - ) - - named_result = building_blocks.Struct( - [_create_struct_element(k) for k in range(len(names_to_add))], - container_type, - ) - return building_blocks.Lambda( - 'x', naming_lambda_arg.type_signature, named_result - ) - - -def create_named_tuple( - comp: building_blocks.ComputationBuildingBlock, - names: Sequence[str], - container_type=None, -) -> building_blocks.ComputationBuildingBlock: - """Creates a computation that applies `names` to `comp`. - - Args: - comp: A `building_blocks.ComputationBuildingBlock` with a `type_signature` - of type `computation_types.StructType`. - names: Python `tuple` or `list` containing instances of type `str` or - `None`, the names to apply to `comp`. - container_type: Optional Python container type to associated with the - resulting tuple. - - Returns: - A `building_blocks.ComputationBuildingBlock` representing a - tuple with the elements from `comp` and the names from `names` attached to - the `type_signature` of those elements. - - Raises: - TypeError: If the types do not match. - """ - py_typecheck.check_type(names, (list, tuple)) - if not all(isinstance(x, (str, type(None))) for x in names): - raise TypeError( - 'Expected `names` containing only instances of `str` or ' - '`None`, found {}'.format(names) - ) - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(comp.type_signature, computation_types.StructType) - fn = _create_naming_function(comp.type_signature, names, container_type) - return building_blocks.Call(fn, comp) - - -def zip_to_match_type( - *, - comp_to_zip: building_blocks.ComputationBuildingBlock, - target_type: computation_types.Type, -) -> Optional[building_blocks.ComputationBuildingBlock]: - """Zips computation argument to match target type. - - This function will apply the appropriate federated zips to match `comp_to_zip` - to the requested type `target_type`, subject to a few caveats. We will - traverse `computation_types.StructTypes` to match types, so for example we - would zip `<>` to match `<@P>`, but we will not traverse - `computation_types.FunctionTypes`. Therefore we would not apply a zip to the - parameter of `(<> -> Q)` to match (<@P> -> Q). - - If zipping in this manner cannot match the type of `comp_to_zip` to - `target_type`, `None` will be returned. - - Args: - comp_to_zip: Instance of `building_blocks.ComputationBuildingBlock` to - traverse and attempt to zip to match `target_type`. - target_type: The type to target when traversing and zipping `comp_to_zip`. - - Returns: - Either a potentially transformed version of `comp_to_zip` or `None`, - depending on whether inserting a zip according to the semantics above - can transformed `comp_to_zip` to the requested type. - """ - py_typecheck.check_type(comp_to_zip, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(target_type, computation_types.Type) - - def _can_be_zipped_into( - source_type: computation_types.Type, target_type: computation_types.Type - ) -> bool: - """Indicates possibility of the transformation `zip_to_match_type`.""" - - def _struct_can_be_zipped_to_federated( - struct_type: computation_types.StructType, - federated_type: computation_types.FederatedType, - ) -> bool: - placements_encountered = set() - - def _remove_placement( - subtype: computation_types.Type, - ) -> tuple[computation_types.Type, bool]: - if isinstance(subtype, computation_types.FederatedType): - placements_encountered.add(subtype.placement) - return subtype.member, True - return subtype, False - - unplaced_struct, _ = type_transformations.transform_type_postorder( - struct_type, _remove_placement - ) - if not ( - all(x is federated_type.placement for x in placements_encountered) - ): - return False - if ( - federated_type.placement is placements.CLIENTS - and federated_type.all_equal - ): - # There is no all-equal clients zip; return false. - return False - return federated_type.member.is_assignable_from(unplaced_struct) - - def _struct_elem_zippable( - source_name, source_element, target_name, target_element - ): - return _can_be_zipped_into( - source_element, target_element - ) and source_name in (target_name, None) - - if isinstance(source_type, computation_types.StructType): - if isinstance(target_type, computation_types.FederatedType): - return _struct_can_be_zipped_to_federated(source_type, target_type) - elif isinstance(target_type, computation_types.StructType): - elements_zippable = [] - for (s_name, s_el), (t_name, t_el) in zip( - source_type.items(), - target_type.items(), - ): - elements_zippable.append( - _struct_elem_zippable(s_name, s_el, t_name, t_el) - ) - return all(elements_zippable) - else: - return target_type.is_assignable_from(source_type) - - def _zip_to_match( - *, - source: building_blocks.ComputationBuildingBlock, - target_type: computation_types.Type, - ): - if isinstance(target_type, computation_types.FederatedType) and isinstance( - source.type_signature, computation_types.StructType - ): - return create_federated_zip(source) - elif isinstance(target_type, computation_types.StructType) and isinstance( - source.type_signature, computation_types.StructType - ): - zipped_elements = [] - # Bind a reference to the source to prevent duplication in the AST. - ref_name = next(unique_name_generator(source)) - ref_to_source = building_blocks.Reference(ref_name, source.type_signature) - for idx, ((_, t_el), (s_name, _)) in enumerate( - zip( - target_type.items(), - source.type_signature.items(), - ) - ): - s_selection = building_blocks.Selection(ref_to_source, index=idx) - zipped_elements.append( - (s_name, _zip_to_match(source=s_selection, target_type=t_el)) - ) - # Insert binding above the constructed structure. - return building_blocks.Block( - [(ref_name, source)], building_blocks.Struct(zipped_elements) - ) - else: - # No zipping to be done here. - return source - - if target_type.is_assignable_from(comp_to_zip.type_signature): - # No zipping needs to be done; return directly. - return comp_to_zip - elif _can_be_zipped_into(comp_to_zip.type_signature, target_type): - return _zip_to_match(source=comp_to_zip, target_type=target_type) - else: - # Zipping cannot be performed here. - return None diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py deleted file mode 100644 index 57ae9b65f2..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test.py +++ /dev/null @@ -1,1638 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.common_libs import golden -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_test_utils - - -class UniqueNameGeneratorTest(absltest.TestCase): - - def test_does_not_raise_type_error_with_none_comp(self): - try: - building_block_factory.unique_name_generator(None) - except TypeError: - self.fail('Raised TypeError unexpectedly.') - - def test_returns_unique_names_with_none_comp_and_none_prefix(self): - name_generator = building_block_factory.unique_name_generator( - None, prefix=None - ) - names = set(next(name_generator) for _ in range(10)) - first_name = list(names)[0] - prefix = first_name[:3] - self.assertLen(names, 10) - self.assertTrue(all(n.startswith(prefix) for n in names)) - - def test_returns_unique_names_with_none_comp_and_unset_prefix(self): - name_generator = building_block_factory.unique_name_generator(None) - names = set(next(name_generator) for _ in range(10)) - self.assertLen(names, 10) - self.assertTrue(all(n.startswith('_var') for n in names)) - - def test_returns_unique_names_with_none_comp_and_prefix(self): - name_generator = building_block_factory.unique_name_generator( - None, prefix='_test' - ) - names = set(next(name_generator) for _ in range(10)) - self.assertLen(names, 10) - self.assertTrue(all(n.startswith('_test') for n in names)) - - def test_returns_unique_names_with_comp_and_none_prefix(self): - ref = building_blocks.Reference('a', np.int32) - comp = building_blocks.Lambda(ref.name, ref.type_signature, ref) - name_generator = building_block_factory.unique_name_generator( - comp, prefix=None - ) - names = set(next(name_generator) for _ in range(10)) - first_name = list(names)[0] - prefix = first_name[:3] - self.assertLen(names, 10) - self.assertTrue(all(n.startswith(prefix) for n in names)) - - def test_returns_unique_names_with_comp_and_unset_prefix(self): - ref = building_blocks.Reference('a', np.int32) - comp = building_blocks.Lambda(ref.name, ref.type_signature, ref) - name_generator = building_block_factory.unique_name_generator(comp) - names = set(next(name_generator) for _ in range(10)) - self.assertLen(names, 10) - self.assertTrue(all(n.startswith('_var') for n in names)) - - def test_returns_unique_names_with_comp_and_prefix(self): - ref = building_blocks.Reference('a', np.int32) - comp = building_blocks.Lambda(ref.name, ref.type_signature, ref) - name_generator = building_block_factory.unique_name_generator( - comp, prefix='_test' - ) - names = set(next(name_generator) for _ in range(10)) - self.assertLen(names, 10) - self.assertTrue(all(n.startswith('_test') for n in names)) - - def test_returns_unique_names_with_conflicting_prefix(self): - ref = building_blocks.Reference('_test', np.int32) - comp = building_blocks.Lambda(ref.name, ref.type_signature, ref) - name_generator = building_block_factory.unique_name_generator( - comp, prefix='_test' - ) - names = set(next(name_generator) for _ in range(10)) - first_name = list(names)[0] - prefix = first_name[:3] - self.assertNotEqual(prefix, '_test') - self.assertTrue(all(n.startswith(prefix) for n in names)) - - -class CreateFederatedGetitemCompTest(parameterized.TestCase): - - def test_raises_type_error_on_none(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_getitem_comp(None, 0) - - @parameterized.named_parameters( - ('clients', placements.CLIENTS), ('server', placements.SERVER) - ) - def test_returns_comp(self, placement): - federated_value = building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placement - ), - ) - get_0_comp = building_block_factory.create_federated_getitem_comp( - federated_value, 0 - ) - self.assertEqual(str(get_0_comp), '(x -> x[0])') - get_slice_comp = building_block_factory.create_federated_getitem_comp( - federated_value, slice(None, None, -1) - ) - self.assertEqual(str(get_slice_comp), '(x -> )') - - -class CreateFederatedGetattrCompTest(parameterized.TestCase): - - def test_raises_type_error_on_none(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_getattr_comp(None, 'x') - - @parameterized.named_parameters( - ('clients', placements.CLIENTS), ('server', placements.SERVER) - ) - def test_returns_comp(self, placement): - federated_value = building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placement - ), - ) - get_a_comp = building_block_factory.create_federated_getattr_comp( - federated_value, 'a' - ) - self.assertEqual(str(get_a_comp), '(x -> x.a)') - get_b_comp = building_block_factory.create_federated_getattr_comp( - federated_value, 'b' - ) - self.assertEqual(str(get_b_comp), '(x -> x.b)') - non_federated_arg = building_blocks.Reference( - 'test', computation_types.StructType([('a', np.int32), ('b', np.bool_)]) - ) - with self.assertRaises(TypeError): - _ = building_block_factory.create_federated_getattr_comp( - non_federated_arg, 'a' - ) - with self.assertRaisesRegex(ValueError, 'has no element of name `c`'): - _ = building_block_factory.create_federated_getattr_comp( - federated_value, 'c' - ) - - -class CreateFederatedGetattrCallTest(parameterized.TestCase): - - def test_raises_type_error_on_none(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_getattr_call(None, 'x') - - @parameterized.named_parameters( - ('clients', placements.CLIENTS), - ('server', placements.SERVER), - ) - def test_returns_named(self, placement): - federated_comp_named = building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_), np.int32], placement - ), - ) - self.assertEqual( - str(federated_comp_named.type_signature.member), - '', - ) - name_a = building_block_factory.create_federated_getattr_call( - federated_comp_named, 'a' - ) - name_b = building_block_factory.create_federated_getattr_call( - federated_comp_named, 'b' - ) - self.assertIsInstance( - name_a.type_signature, computation_types.FederatedType - ) - self.assertIsInstance( - name_b.type_signature, computation_types.FederatedType - ) - self.assertEqual(str(name_a.type_signature.member), 'int32') - self.assertEqual(str(name_b.type_signature.member), 'bool') - try: - type_analysis.check_federated_type( - name_a.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - try: - type_analysis.check_federated_type( - name_b.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - with self.assertRaisesRegex(ValueError, 'has no element of name `c`'): - _ = building_block_factory.create_federated_getattr_call( - federated_comp_named, 'c' - ) - - -class CreateFederatedGetitemCallTest(parameterized.TestCase): - - def test_fails_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_getitem_call(None, 0) - - @parameterized.named_parameters( - ('clients', placements.CLIENTS), - ('server', placements.SERVER), - ) - def test_returns_named(self, placement): - federated_comp_named = building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placement - ), - ) - self.assertEqual( - str(federated_comp_named.type_signature.member), '' - ) - idx_0 = building_block_factory.create_federated_getitem_call( - federated_comp_named, 0 - ) - idx_1 = building_block_factory.create_federated_getitem_call( - federated_comp_named, 1 - ) - self.assertIsInstance(idx_0.type_signature, computation_types.FederatedType) - self.assertIsInstance(idx_1.type_signature, computation_types.FederatedType) - self.assertEqual(str(idx_0.type_signature.member), 'int32') - self.assertEqual(str(idx_1.type_signature.member), 'bool') - try: - type_analysis.check_federated_type( - idx_0.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - try: - type_analysis.check_federated_type( - idx_1.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - flipped = building_block_factory.create_federated_getitem_call( - federated_comp_named, slice(None, None, -1) - ) - self.assertIsInstance( - flipped.type_signature, computation_types.FederatedType - ) - self.assertEqual(str(flipped.type_signature.member), '') - try: - type_analysis.check_federated_type( - flipped.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - - @parameterized.named_parameters( - ('clients', placements.CLIENTS), - ('server', placements.SERVER), - ) - def test_returns_unnamed(self, placement): - federated_comp_unnamed = building_blocks.Reference( - 'test', computation_types.FederatedType([np.int32, np.bool_], placement) - ) - self.assertEqual( - str(federated_comp_unnamed.type_signature.member), '' - ) - unnamed_idx_0 = building_block_factory.create_federated_getitem_call( - federated_comp_unnamed, 0 - ) - unnamed_idx_1 = building_block_factory.create_federated_getitem_call( - federated_comp_unnamed, 1 - ) - self.assertIsInstance( - unnamed_idx_0.type_signature, computation_types.FederatedType - ) - self.assertIsInstance( - unnamed_idx_1.type_signature, computation_types.FederatedType - ) - self.assertEqual(str(unnamed_idx_0.type_signature.member), 'int32') - self.assertEqual(str(unnamed_idx_1.type_signature.member), 'bool') - try: - type_analysis.check_federated_type( - unnamed_idx_0.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - try: - type_analysis.check_federated_type( - unnamed_idx_1.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - unnamed_flipped = building_block_factory.create_federated_getitem_call( - federated_comp_unnamed, slice(None, None, -1) - ) - self.assertIsInstance( - unnamed_flipped.type_signature, computation_types.FederatedType - ) - self.assertEqual(str(unnamed_flipped.type_signature.member), '') - try: - type_analysis.check_federated_type( - unnamed_flipped.type_signature, placement=placement - ) - except TypeError: - self.fail( - "Function 'check_federated_type' raised TypeError unexpectedly." - ) - - -class CreateFederatedAggregateTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - zero = building_blocks.Literal(0, computation_types.TensorType(np.int32)) - accumulate_type = computation_types.StructType((np.int32, np.int32)) - accumulate_result = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ) - accumulate = building_blocks.Lambda('x', accumulate_type, accumulate_result) - merge_type = computation_types.StructType((np.int32, np.int32)) - merge_result = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - merge = building_blocks.Lambda('x', merge_type, merge_result) - report_ref = building_blocks.Reference('r', np.int32) - report = building_blocks.Lambda( - report_ref.name, report_ref.type_signature, report_ref - ) - with self.assertRaises(TypeError): - building_block_factory.create_federated_aggregate( - None, zero, accumulate, merge, report - ) - - def test_raises_type_error_with_none_zero(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(0, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - accumulate_type = computation_types.StructType((np.int32, np.int32)) - accumulate_result = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ) - accumulate = building_blocks.Lambda('x', accumulate_type, accumulate_result) - merge_type = computation_types.StructType((np.int32, np.int32)) - merge_result = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - merge = building_blocks.Lambda('x', merge_type, merge_result) - report_ref = building_blocks.Reference('r', np.int32) - report = building_blocks.Lambda( - report_ref.name, report_ref.type_signature, report_ref - ) - with self.assertRaises(TypeError): - building_block_factory.create_federated_aggregate( - value, None, accumulate, merge, report - ) - - def test_raises_type_error_with_none_accumulate(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(0, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - zero = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - merge_type = computation_types.StructType((np.int32, np.int32)) - merge_result = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - merge = building_blocks.Lambda('x', merge_type, merge_result) - report_ref = building_blocks.Reference('r', np.int32) - report = building_blocks.Lambda( - report_ref.name, report_ref.type_signature, report_ref - ) - with self.assertRaises(TypeError): - building_block_factory.create_federated_aggregate( - value, zero, None, merge, report - ) - - def test_raises_type_error_with_none_merge(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(0, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - zero = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - accumulate_type = computation_types.StructType((np.int32, np.int32)) - accumulate_result = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - accumulate = building_blocks.Lambda('x', accumulate_type, accumulate_result) - report_ref = building_blocks.Reference('r', np.int32) - report = building_blocks.Lambda( - report_ref.name, report_ref.type_signature, report_ref - ) - with self.assertRaises(TypeError): - building_block_factory.create_federated_aggregate( - value, zero, accumulate, None, report - ) - - def test_raises_type_error_with_none_report(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(0, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - zero = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - accumulate_type = computation_types.StructType((np.int32, np.int32)) - accumulate_result = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - accumulate = building_blocks.Lambda('x', accumulate_type, accumulate_result) - merge_type = computation_types.StructType((np.int32, np.int32)) - merge_result = building_blocks.Literal( - 3, computation_types.TensorType(np.int32) - ) - merge = building_blocks.Lambda('x', merge_type, merge_result) - with self.assertRaises(TypeError): - building_block_factory.create_federated_aggregate( - value, zero, accumulate, merge, None - ) - - def test_returns_federated_aggregate(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(0, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - zero = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - accumulate_type = computation_types.StructType((np.int32, np.int32)) - accumulate_result = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - accumulate = building_blocks.Lambda('x', accumulate_type, accumulate_result) - merge_type = computation_types.StructType((np.int32, np.int32)) - merge_result = building_blocks.Literal( - 3, computation_types.TensorType(np.int32) - ) - merge = building_blocks.Lambda('x', merge_type, merge_result) - report_ref = building_blocks.Reference('r', np.int32) - report = building_blocks.Lambda( - report_ref.name, report_ref.type_signature, report_ref - ) - comp = building_block_factory.create_federated_aggregate( - value, zero, accumulate, merge, report - ) - self.assertEqual( - comp.compact_representation(), - 'federated_aggregate( 2),(x ->' - ' 3),(r -> r)>)', - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - -class CreateFederatedApplyTest(absltest.TestCase): - - def test_raises_type_error_with_none_fn(self): - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - building_block_factory.create_federated_apply(None, arg) - - def test_raises_type_error_with_none_arg(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - with self.assertRaises(TypeError): - building_block_factory.create_federated_apply(fn, None) - - def test_returns_federated_apply(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.SERVER, - ) - comp = building_block_factory.create_federated_apply(fn, arg) - self.assertEqual( - comp.compact_representation(), - 'federated_apply(<(x -> x),federated_value_at_server(1)>)', - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - -class CreateFederatedBroadcastTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_broadcast(None) - - def test_returns_federated_broadcast(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.SERVER, - ) - comp = building_block_factory.create_federated_broadcast(value) - self.assertEqual( - comp.compact_representation(), - 'federated_broadcast(federated_value_at_server(1))', - ) - self.assertEqual(str(comp.type_signature), 'int32@CLIENTS') - - -class CreateFederatedEvalTest(absltest.TestCase): - - def assert_type_error(self, fn, placement): - with self.assertRaises(TypeError): - building_block_factory.create_federated_eval(fn, placement) - - def test_raises_type_error_with_none_fn(self): - self.assert_type_error(None, placements.CLIENTS) - - def test_raises_type_error_with_nonfunctional_fn(self): - fn = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - self.assert_type_error(fn, placements.CLIENTS) - - def test_returns_federated_eval(self): - fn = building_blocks.Reference( - 'y', computation_types.FunctionType(None, np.int32) - ) - comp = building_block_factory.create_federated_eval(fn, placements.CLIENTS) - self.assertEqual( - comp.compact_representation(), 'federated_eval_at_clients(y)' - ) - self.assertEqual(str(comp.type_signature), '{int32}@CLIENTS') - - -class CreateFederatedMapTest(absltest.TestCase): - - def test_raises_type_error_with_none_fn(self): - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - building_block_factory.create_federated_map(None, arg) - - def test_raises_type_error_with_none_arg(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - with self.assertRaises(TypeError): - building_block_factory.create_federated_map(fn, None) - - def test_returns_federated_map(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_map(fn, arg) - self.assertEqual( - comp.compact_representation(), - 'federated_map(<(x -> x),federated_value_at_clients(1)>)', - ) - self.assertEqual(str(comp.type_signature), '{int32}@CLIENTS') - - -class CreateFederatedMapAllEqualTest(absltest.TestCase): - - def test_raises_type_error_with_none_fn(self): - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - building_block_factory.create_federated_map_all_equal(None, arg) - - def test_raises_type_error_with_none_arg(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - with self.assertRaises(TypeError): - building_block_factory.create_federated_map_all_equal(fn, None) - - def test_returns_federated_map_all_equal(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_map_all_equal(fn, arg) - self.assertEqual( - comp.compact_representation(), - 'federated_map_all_equal(<(x -> x),federated_value_at_clients(1)>)', - ) - self.assertEqual(str(comp.type_signature), 'int32@CLIENTS') - - -class CreateFederatedMapOrApplyTest(absltest.TestCase): - - def test_raises_type_error_with_none_fn(self): - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - building_block_factory.create_federated_map_or_apply(None, arg) - - def test_raises_type_error_with_none_arg(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - with self.assertRaises(TypeError): - building_block_factory.create_federated_map_or_apply(fn, None) - - def test_returns_federated_apply(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.SERVER, - ) - comp = building_block_factory.create_federated_map_or_apply(fn, arg) - self.assertEqual( - comp.compact_representation(), - 'federated_apply(<(x -> x),federated_value_at_server(1)>)', - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - def test_returns_federated_map(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_map_or_apply(fn, arg) - self.assertEqual( - comp.compact_representation(), - 'federated_map_all_equal(<(x -> x),federated_value_at_clients(1)>)', - ) - self.assertEqual(str(comp.type_signature), 'int32@CLIENTS') - - -class CreateFederatedMeanTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_mean(None, None) - - def test_returns_federated_mean(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_mean(value, None) - self.assertEqual( - comp.compact_representation(), - 'federated_mean(federated_value_at_clients(1))', - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - def test_returns_federated_weighted_mean(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - weight = building_block_factory.create_federated_value( - building_blocks.Literal(2, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_mean(value, weight) - self.assertEqual( - comp.compact_representation(), - 'federated_weighted_mean()', - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - -class CreateFederatedMinTest(absltest.TestCase): - - def test_returns_federated_min(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_min(value) - self.assertEqual( - comp.compact_representation(), - 'federated_min(federated_value_at_clients(1))', - ) - self.assertEqual( - comp.type_signature.compact_representation(), 'int32@SERVER' - ) - - -class CreateFederatedMaxTest(absltest.TestCase): - - def test_returns_federated_max(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_max(value) - self.assertEqual( - comp.compact_representation(), - 'federated_max(federated_value_at_clients(1))', - ) - self.assertEqual( - comp.type_signature.compact_representation(), 'int32@SERVER' - ) - - -class CreateFederatedSecureSumTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - max_input = mock.create_autospec( - building_blocks.CompiledComputation, spec_set=True, instance=True - ) - - with self.assertRaises(TypeError): - building_block_factory.create_federated_secure_sum(None, max_input) - - def test_raises_type_error_with_none_max_input(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - - with self.assertRaises(TypeError): - building_block_factory.create_federated_secure_sum(value, None) - - def test_returns_federated_sum(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - max_value_type = computation_types.TensorType(np.int32) - max_value = building_blocks.Literal(2, max_value_type) - comp = building_block_factory.create_federated_secure_sum(value, max_value) - self.assertEqual( - comp.compact_representation(), - 'federated_secure_sum()', - ) - self.assertEqual( - comp.type_signature.compact_representation(), 'int32@SERVER' - ) - - -class CreateFederatedSecureSumBitwidthTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - bitwidth = mock.create_autospec( - building_blocks.CompiledComputation, spec_set=True, instance=True - ) - - with self.assertRaises(TypeError): - building_block_factory.create_federated_secure_sum_bitwidth( - None, bitwidth - ) - - def test_raises_type_error_with_none_bitwidth(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - - with self.assertRaises(TypeError): - building_block_factory.create_federated_secure_sum_bitwidth(value, None) - - def test_returns_federated_sum(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - bitwidth_type = computation_types.TensorType(np.int32) - bitwidth = building_blocks.Literal(2, bitwidth_type) - comp = building_block_factory.create_federated_secure_sum_bitwidth( - value, bitwidth - ) - self.assertEqual( - comp.compact_representation(), - 'federated_secure_sum_bitwidth()', - ) - self.assertEqual( - comp.type_signature.compact_representation(), 'int32@SERVER' - ) - - -class CreateFederatedSelectTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('non_secure', False, 'federated_select'), - ('secure', True, 'federated_secure_select'), - ) - def test_returns_federated_select(self, secure, name): - client_keys = building_block_factory.create_federated_value( - building_blocks.Literal( - np.array([5, 4, 3, 2, 1], dtype=np.int32), - computation_types.TensorType(np.int32, [5]), - ), - placement=placements.CLIENTS, - ) - max_key = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.SERVER, - ) - server_val_type = computation_types.SequenceType(np.str_) - server_val = building_blocks.Reference( - 'server_val', - computation_types.FederatedType(server_val_type, placements.SERVER), - ) - select_fn = building_blocks.Reference( - 'select_fn', - computation_types.FunctionType( - computation_types.StructType([ - ('some_name_for_server_val', server_val_type), - ('some_name_for_key', np.int32), - ]), - np.str_, - ), - ) - comp = building_block_factory.create_federated_select( - client_keys, max_key, server_val, select_fn, secure - ) - self.assertEqual( - comp.compact_representation(), - f'{name}( select_fn(a))>)', - ) - self.assertEqual( - comp.type_signature.compact_representation(), '{str*}@CLIENTS' - ) - - -class CreateFederatedSumTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_sum(None) - - def test_returns_federated_sum(self): - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, - ) - comp = building_block_factory.create_federated_sum(value) - self.assertEqual( - comp.compact_representation(), - 'federated_sum(federated_value_at_clients(1))', - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - -class CreateFederatedUnzipTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_unzip(None) - - def test_returns_tuple_federated_map_with_empty_value(self): - value_type = computation_types.FederatedType([], placements.CLIENTS) - value = building_blocks.Reference('v', value_type) - with self.assertRaises(ValueError): - building_block_factory.create_federated_unzip(value) - - def test_returns_tuple_federated_map_with_one_value_unnamed(self): - value_type = computation_types.FederatedType( - (np.int32,), placements.CLIENTS - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - self.assertEqual( - comp.compact_representation(), - '(let value=v in arg[0]),value>)>)', - ) - self.assertEqual(str(comp.type_signature), '<{int32}@CLIENTS>') - - def test_returns_tuple_federated_map_with_one_value_named(self): - type_signature = computation_types.StructType((('a', np.int32),)) - value_type = computation_types.FederatedType( - type_signature, placements.CLIENTS - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - self.assertEqual( - comp.compact_representation(), - '(let value=v in arg[0]),value>)>)', - ) - self.assertEqual(str(comp.type_signature), '') - - def test_returns_tuple_federated_map_with_two_values_unnamed(self): - value_type = computation_types.FederatedType( - (np.int32, np.int32), placements.CLIENTS - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - golden.check_string( - 'tuple_federated_map_with_two_values_unnamed.expected', - comp.formatted_representation(), - ) - self.assertEqual( - str(comp.type_signature), '<{int32}@CLIENTS,{int32}@CLIENTS>' - ) - - def test_returns_tuple_federated_map_with_two_values_named(self): - type_signature = computation_types.StructType( - (('a', np.int32), ('b', np.int32)) - ) - value_type = computation_types.FederatedType( - type_signature, placements.CLIENTS - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - golden.check_string( - 'tuple_federated_map_with_two_values_named.expected', - comp.formatted_representation(), - ) - self.assertEqual( - str(comp.type_signature), '' - ) - - def test_returns_tuple_federated_map_with_two_values_different_typed(self): - value_type = computation_types.FederatedType( - (np.int32, np.bool_), placements.CLIENTS - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - golden.check_string( - 'tuple_federated_map_with_two_values_different_typed.expected', - comp.formatted_representation(), - ) - self.assertEqual( - str(comp.type_signature), '<{int32}@CLIENTS,{bool}@CLIENTS>' - ) - - def test_returns_tuple_federated_apply_with_one_value_unnamed(self): - value_type = computation_types.FederatedType((np.int32,), placements.SERVER) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - self.assertEqual( - comp.compact_representation(), - '(let value=v in arg[0]),value>)>)', - ) - self.assertEqual(str(comp.type_signature), '') - - def test_returns_tuple_federated_apply_with_one_value_named(self): - type_signature = computation_types.StructType((('a', np.int32),)) - value_type = computation_types.FederatedType( - type_signature, placements.SERVER - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - self.assertEqual( - comp.compact_representation(), - '(let value=v in arg[0]),value>)>)', - ) - self.assertEqual(str(comp.type_signature), '') - - def test_returns_tuple_federated_apply_with_two_values_unnamed(self): - value_type = computation_types.FederatedType( - (np.int32, np.int32), placements.SERVER - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - golden.check_string( - 'tuple_federated_apply_with_two_values_unnamed.expected', - comp.formatted_representation(), - ) - self.assertEqual(str(comp.type_signature), '') - - def test_returns_tuple_federated_apply_with_two_values_named(self): - type_signature = computation_types.StructType( - (('a', np.int32), ('b', np.int32)) - ) - value_type = computation_types.FederatedType( - type_signature, placements.SERVER - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - golden.check_string( - 'tuple_federated_apply_with_two_values_named.expected', - comp.formatted_representation(), - ) - self.assertEqual( - str(comp.type_signature), '' - ) - - def test_returns_tuple_federated_apply_with_two_values_different_typed(self): - value_type = computation_types.FederatedType( - (np.int32, np.bool_), placements.SERVER - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_unzip(value) - golden.check_string( - 'tuple_federated_apply_with_two_values_different_typed.expected', - comp.formatted_representation(), - ) - self.assertEqual(str(comp.type_signature), '') - - -class CreateFederatedValueTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_federated_value(None, placements.CLIENTS) - - def test_raises_type_error_with_none_placement(self): - value = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - building_block_factory.create_federated_value(value, None) - - def test_raises_type_error_with_unknown_placement(self): - value = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - building_block_factory.create_federated_value(value, 'unknown') - - def test_returns_federated_value_at_clients(self): - value = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - comp = building_block_factory.create_federated_value( - value, placements.CLIENTS - ) - self.assertEqual( - comp.compact_representation(), 'federated_value_at_clients(1)' - ) - self.assertEqual(str(comp.type_signature), 'int32@CLIENTS') - - def test_returns_federated_value_at_server(self): - value = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - comp = building_block_factory.create_federated_value( - value, placements.SERVER - ) - self.assertEqual( - comp.compact_representation(), 'federated_value_at_server(1)' - ) - self.assertEqual(str(comp.type_signature), 'int32@SERVER') - - -INT_AT_CLIENTS = computation_types.FederatedType(np.int32, placements.CLIENTS) -BOOL_AT_CLIENTS = computation_types.FederatedType(np.bool_, placements.CLIENTS) -INT_AT_SERVER = computation_types.FederatedType(np.int32, placements.SERVER) -BOOL_AT_SERVER = computation_types.FederatedType(np.bool_, placements.SERVER) - - -class CreateFederatedZipTest(parameterized.TestCase, absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaisesRegex(TypeError, 'found NoneType'): - building_block_factory.create_federated_zip(None) - - def test_raises_type_error_with_empty_value(self): - value_type = computation_types.StructType([]) - value = building_blocks.Reference('v', value_type) - with self.assertRaisesRegex(TypeError, 'at least one FederatedType'): - building_block_factory.create_federated_zip(value) - - @parameterized.named_parameters([ - ( - 'one_unnamed', - computation_types.StructType((INT_AT_CLIENTS,)), - computation_types.StructType((np.int32,)), - ), - ( - 'one_named', - computation_types.StructType((('a', INT_AT_CLIENTS),)), - computation_types.StructType((('a', np.int32),)), - ), - ( - 'two_unnamed', - computation_types.StructType((INT_AT_CLIENTS,) * 2), - computation_types.StructType((np.int32,) * 2), - ), - ( - 'two_named', - computation_types.StructType( - (('a', INT_AT_CLIENTS), ('b', INT_AT_CLIENTS)) - ), - computation_types.StructType((('a', np.int32), ('b', np.int32))), - ), - ( - 'different_typed', - computation_types.StructType((BOOL_AT_CLIENTS, INT_AT_CLIENTS)), - computation_types.StructType((np.bool_, np.int32)), - ), - ('three_tuple', (INT_AT_CLIENTS,) * 3, (np.int32, np.int32, np.int32)), - ( - 'three_dict', - collections.OrderedDict( - a=INT_AT_CLIENTS, b=INT_AT_CLIENTS, c=BOOL_AT_CLIENTS - ), - computation_types.StructType( - collections.OrderedDict(a=np.int32, b=np.int32, c=np.bool_) - ), - ), - ]) - def test_returns_zip_at_clients(self, value_type, expected_zipped_type): - expected_zipped_type = computation_types.FederatedType( - expected_zipped_type, placements.CLIENTS - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_zip(value) - self.assertEqual( - comp.formatted_representation(), 'federated_zip_at_clients(v)' - ) - type_test_utils.assert_types_equivalent( - expected_zipped_type, comp.type_signature - ) - - @parameterized.named_parameters([ - ( - 'one_unnamed', - computation_types.StructType((INT_AT_SERVER,)), - computation_types.StructType((np.int32,)), - ), - ( - 'one_named', - computation_types.StructType((('a', INT_AT_SERVER),)), - computation_types.StructType((('a', np.int32),)), - ), - ( - 'two_unnamed', - computation_types.StructType((INT_AT_SERVER,) * 2), - computation_types.StructType((np.int32,) * 2), - ), - ( - 'two_named', - computation_types.StructType( - (('a', INT_AT_SERVER), ('b', INT_AT_SERVER)) - ), - computation_types.StructType((('a', np.int32), ('b', np.int32))), - ), - ( - 'different_typed', - computation_types.StructType((BOOL_AT_SERVER, INT_AT_SERVER)), - computation_types.StructType((np.bool_, np.int32)), - ), - ('three_tuple', (INT_AT_SERVER,) * 3, (np.int32, np.int32, np.int32)), - ( - 'three_dict', - collections.OrderedDict( - a=INT_AT_SERVER, b=INT_AT_SERVER, c=BOOL_AT_SERVER - ), - computation_types.StructType( - collections.OrderedDict(a=np.int32, b=np.int32, c=np.bool_) - ), - ), - ]) - def test_returns_zip_at_server(self, value_type, expected_zipped_type): - expected_zipped_type = computation_types.FederatedType( - expected_zipped_type, placements.SERVER - ) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_federated_zip(value) - self.assertEqual( - comp.formatted_representation(), 'federated_zip_at_server(v)' - ) - type_test_utils.assert_types_equivalent( - expected_zipped_type, comp.type_signature - ) - - def test_flat_raises_type_error_with_inconsistent_placement(self): - client_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True - ) - server_type = computation_types.FederatedType( - np.int32, placements.SERVER, all_equal=True - ) - value_type = computation_types.StructType( - [('a', client_type), ('b', server_type)] - ) - value = building_blocks.Reference('v', value_type) - self.assertEqual( - value.type_signature.compact_representation(), - '', - ) - with self.assertRaisesRegex(TypeError, 'same placement'): - building_block_factory.create_federated_zip(value) - - def test_nested_raises_type_error_with_inconsistent_placement(self): - client_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True - ) - server_type = computation_types.FederatedType( - np.int32, placements.SERVER, all_equal=True - ) - tuple_type = computation_types.StructType( - [('c', server_type), ('d', server_type)] - ) - value_type = computation_types.StructType( - [('a', client_type), ('b', tuple_type)] - ) - value = building_blocks.Reference('v', value_type) - self.assertEqual( - value.type_signature.compact_representation(), - '>', - ) - with self.assertRaisesRegex(TypeError, 'same placement'): - building_block_factory.create_federated_zip(value) - - def test_flat_raises_type_error_with_unplaced(self): - client_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True - ) - value_type = computation_types.StructType( - [('a', client_type), ('b', np.int32)] - ) - value = building_blocks.Reference('v', value_type) - self.assertEqual( - value.type_signature.compact_representation(), - '', - ) - with self.assertRaises(TypeError): - building_block_factory.create_federated_zip(value) - - def test_nested_raises_type_error_with_unplaced(self): - client_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True - ) - tuple_type = computation_types.StructType( - [('c', np.int32), ('d', np.int32)] - ) - value_type = computation_types.StructType( - [('a', client_type), ('b', tuple_type)] - ) - value = building_blocks.Reference('v', value_type) - self.assertEqual( - value.type_signature.compact_representation(), - '>', - ) - with self.assertRaises(TypeError): - building_block_factory.create_federated_zip(value) - - -class CreateSequenceMapTest(absltest.TestCase): - - def test_raises_type_error_with_none_fn(self): - arg_type = computation_types.SequenceType(np.int32) - arg = building_blocks.Reference('y', arg_type) - with self.assertRaises(TypeError): - building_block_factory.create_sequence_map(None, arg) - - def test_raises_type_error_with_none_arg(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - with self.assertRaises(TypeError): - building_block_factory.create_sequence_map(fn, None) - - def test_returns_sequence_map(self): - ref = building_blocks.Reference('x', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - arg_type = computation_types.SequenceType(np.int32) - arg = building_blocks.Reference('y', arg_type) - comp = building_block_factory.create_sequence_map(fn, arg) - self.assertEqual( - comp.compact_representation(), 'sequence_map(<(x -> x),y>)' - ) - self.assertEqual(str(comp.type_signature), 'int32*') - - -class CreateSequenceReduceTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - zero = building_blocks.Reference('z', np.int32) - op_type = computation_types.StructType((np.int32, np.int32)) - op_result = building_blocks.Reference('o', np.int32) - op = building_blocks.Lambda('x', op_type, op_result) - with self.assertRaises(TypeError): - building_block_factory.create_sequence_reduce(None, zero, op) - - def test_raises_type_error_with_none_zero(self): - value_type = computation_types.SequenceType(np.int32) - value = building_blocks.Reference('v', value_type) - op_type = computation_types.StructType((np.int32, np.int32)) - op_result = building_blocks.Reference('o', np.int32) - op = building_blocks.Lambda('x', op_type, op_result) - with self.assertRaises(TypeError): - building_block_factory.create_sequence_reduce(value, None, op) - - def test_raises_type_error_with_none_op(self): - value_type = computation_types.SequenceType(np.int32) - value = building_blocks.Reference('v', value_type) - zero = building_blocks.Reference('z', np.int32) - with self.assertRaises(TypeError): - building_block_factory.create_sequence_reduce(value, zero, None) - - def test_returns_sequence_reduce(self): - value_type = computation_types.SequenceType(np.int32) - value = building_blocks.Reference('v', value_type) - zero = building_blocks.Reference('z', np.int32) - op_type = computation_types.StructType((np.int32, np.int32)) - op_result = building_blocks.Reference('o', np.int32) - op = building_blocks.Lambda('x', op_type, op_result) - comp = building_block_factory.create_sequence_reduce(value, zero, op) - self.assertEqual( - comp.compact_representation(), 'sequence_reduce( o)>)' - ) - self.assertEqual(str(comp.type_signature), 'int32') - - -class CreateSequenceSumTest(absltest.TestCase): - - def test_raises_type_error_with_none_value(self): - with self.assertRaises(TypeError): - building_block_factory.create_sequence_sum(None) - - def test_returns_federated_sum(self): - value_type = computation_types.SequenceType(np.int32) - value = building_blocks.Reference('v', value_type) - comp = building_block_factory.create_sequence_sum(value) - self.assertEqual(comp.compact_representation(), 'sequence_sum(v)') - self.assertEqual(str(comp.type_signature), 'int32') - - -class CreateNamedTupleTest(absltest.TestCase): - - def test_raises_type_error_with_none_comp(self): - with self.assertRaises(TypeError): - building_block_factory.create_named_tuple(None, ('a',)) - - def test_raises_type_error_with_wrong_comp_type(self): - comp = building_blocks.Reference('data', np.int32) - with self.assertRaises(TypeError): - building_block_factory.create_named_tuple(comp, ('a',)) - - def test_raises_type_error_with_wrong_names_type_string(self): - type_signature = computation_types.StructType((np.int32, np.int32)) - comp = building_blocks.Reference('data', type_signature) - with self.assertRaises(TypeError): - building_block_factory.create_named_tuple(comp, 'a') - - def test_raises_type_error_with_wrong_names_type_ints(self): - type_signature = computation_types.StructType((np.int32, np.int32)) - comp = building_blocks.Reference('data', type_signature) - with self.assertRaises(TypeError): - building_block_factory.create_named_tuple(comp, 'a') - - def test_raises_value_error_with_wrong_lengths(self): - type_signature = computation_types.StructType((np.int32, np.int32)) - comp = building_blocks.Reference('data', type_signature) - with self.assertRaises(ValueError): - building_block_factory.create_named_tuple(comp, ('a',)) - - def test_creates_named_tuple_from_unamed_tuple(self): - type_signature = computation_types.StructType((np.int32, np.int32)) - comp = building_blocks.Reference('data', type_signature) - named_comp = building_block_factory.create_named_tuple(comp, ('a', 'b')) - expected_type_signature = computation_types.StructType( - (('a', np.int32), ('b', np.int32)) - ) - self.assertEqual(named_comp.type_signature, expected_type_signature) - - def test_creates_named_tuple_from_named_tuple(self): - type_signature = computation_types.StructType( - (('a', np.int32), ('b', np.int32)) - ) - comp = building_blocks.Reference('data', type_signature) - named_comp = building_block_factory.create_named_tuple(comp, ('c', 'd')) - expected_type_signature = computation_types.StructType( - (('c', np.int32), ('d', np.int32)) - ) - self.assertEqual(named_comp.type_signature, expected_type_signature) - - -def identity_for_type( - input_type: computation_types.Type, -) -> building_blocks.Lambda: - """Returns an identity computation for the provided `input_type`.""" - return building_blocks.Lambda( - 'x', input_type, building_blocks.Reference('x', input_type) - ) - - -class SelectOutputFromLambdaTest(absltest.TestCase): - - def test_raises_on_non_str_int_index(self): - lam = identity_for_type(computation_types.StructType([np.int32])) - with self.assertRaisesRegex(TypeError, 'Invalid selection type'): - building_block_factory.select_output_from_lambda(lam, [dict()]) - - def test_selects_single_output(self): - input_type = computation_types.StructType([np.int32, np.float32]) - lam = identity_for_type(input_type) - zero_selected = building_block_factory.select_output_from_lambda(lam, 0) - type_test_utils.assert_types_equivalent( - zero_selected.type_signature.parameter, lam.type_signature.parameter - ) - type_test_utils.assert_types_equivalent( - zero_selected.type_signature.result, lam.type_signature.result[0] - ) - self.assertEqual(str(zero_selected), '(x -> x[0])') - - def test_selects_single_output_by_str(self): - input_type = computation_types.StructType([('a', np.int32)]) - lam = identity_for_type(input_type) - selected = building_block_factory.select_output_from_lambda(lam, 'a') - type_test_utils.assert_types_equivalent( - selected.type_signature, - computation_types.FunctionType( - lam.parameter_type, lam.type_signature.result['a'] - ), - ) - - def test_selects_from_struct_by_removing_struct_wrapper(self): - lam = building_blocks.Lambda( - 'x', - np.int32, - building_blocks.Struct([building_blocks.Reference('x', np.int32)]), - ) - selected = building_block_factory.select_output_from_lambda(lam, 0) - type_test_utils.assert_types_equivalent( - selected.type_signature.result, computation_types.TensorType(np.int32) - ) - self.assertEqual(str(selected), '(x -> x)') - - def test_selects_struct_of_outputs(self): - input_type = computation_types.StructType([np.int32, np.int64, np.float32]) - lam = identity_for_type(input_type) - tuple_selected = building_block_factory.select_output_from_lambda( - lam, [0, 1] - ) - type_test_utils.assert_types_equivalent( - tuple_selected.type_signature.parameter, lam.type_signature.parameter - ) - type_test_utils.assert_types_equivalent( - tuple_selected.type_signature.result, - computation_types.StructType( - [lam.type_signature.result[0], lam.type_signature.result[1]] - ), - ) - self.assertEqual( - str(tuple_selected), '(x -> (let _var1=x in <_var1[0],_var1[1]>))' - ) - - def test_selects_struct_of_outputs_by_str_name(self): - input_type = computation_types.StructType( - [('a', np.int32), ('b', np.int64), ('c', np.float32)] - ) - lam = identity_for_type(input_type) - selected = building_block_factory.select_output_from_lambda(lam, ['a', 'b']) - type_test_utils.assert_types_equivalent( - selected.type_signature, - computation_types.FunctionType( - lam.parameter_type, - computation_types.StructType( - [lam.type_signature.result.a, lam.type_signature.result.b] - ), - ), - ) - - def test_selects_nested_federated_outputs(self): - input_type = computation_types.StructType([ - ('a', computation_types.StructType([('inner', np.int32)])), - ('b', np.int32), - ]) - lam = identity_for_type(input_type) - tuple_selected = building_block_factory.select_output_from_lambda( - lam, [('a', 'inner'), 'b'] - ) - type_test_utils.assert_types_equivalent( - tuple_selected.type_signature.parameter, lam.type_signature.parameter - ) - type_test_utils.assert_types_equivalent( - tuple_selected.type_signature.result, - computation_types.StructType( - [lam.type_signature.result.a.inner, lam.type_signature.result.b] - ), - ) - self.assertEqual( - str(tuple_selected), '(x -> (let _var1=x in <_var1.a.inner,_var1.b>))' - ) - - -class ZipUpToTest(absltest.TestCase): - - def test_zips_struct_of_federated_values(self): - comp = building_blocks.Struct([ - building_blocks.Reference( - 'x', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - building_blocks.Reference( - 'y', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]) - zippable_type = computation_types.FederatedType( - computation_types.StructType([(None, np.int32), (None, np.int32)]), - placements.CLIENTS, - ) - zipped = building_block_factory.zip_to_match_type( - comp_to_zip=comp, target_type=zippable_type - ) - type_test_utils.assert_types_equivalent( - zipped.type_signature, zippable_type - ) - - def test_does_not_zip_different_placement_target(self): - comp = building_blocks.Struct([ - building_blocks.Reference( - 'x', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - building_blocks.Reference( - 'y', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]) - non_zippable_type = computation_types.FederatedType( - computation_types.StructType([(None, np.int32), (None, np.int32)]), - placements.SERVER, - ) - zipped = building_block_factory.zip_to_match_type( - comp_to_zip=comp, target_type=non_zippable_type - ) - self.assertIsNone(zipped) - - def test_zips_struct_of_federated_values_under_struct(self): - comp = building_blocks.Struct([ - building_blocks.Struct([ - building_blocks.Reference( - 'x', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - building_blocks.Reference( - 'y', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]) - ]) - zippable_type = computation_types.StructType([( - None, - computation_types.FederatedType( - computation_types.StructType([(None, np.int32), (None, np.int32)]), - placements.CLIENTS, - ), - )]) - zipped = building_block_factory.zip_to_match_type( - comp_to_zip=comp, target_type=zippable_type - ) - type_test_utils.assert_types_equivalent( - zipped.type_signature, zippable_type - ) - - def test_assignability_with_names(self): - # This would correspond to an implicit downcast in TFF's typesystem; the - # result would not be assignable to the requested type. - comp = building_blocks.Struct([ - building_blocks.Struct([ - ( - 'a', - building_blocks.Reference( - 'x', - computation_types.FederatedType( - np.int32, placements.CLIENTS - ), - ), - ), - ( - 'b', - building_blocks.Reference( - 'y', - computation_types.FederatedType( - np.int32, placements.CLIENTS - ), - ), - ), - ]) - ]) - unnamed_zippable_type = computation_types.StructType([( - None, - computation_types.FederatedType( - computation_types.StructType([(None, np.int32), (None, np.int32)]), - placements.CLIENTS, - ), - )]) - named_zippable_type = computation_types.StructType([( - None, - computation_types.FederatedType( - computation_types.StructType([('a', np.int32), ('b', np.int32)]), - placements.CLIENTS, - ), - )]) - - not_zipped = building_block_factory.zip_to_match_type( - comp_to_zip=comp, target_type=unnamed_zippable_type - ) - zipped = building_block_factory.zip_to_match_type( - comp_to_zip=comp, target_type=named_zippable_type - ) - - self.assertFalse( - unnamed_zippable_type.is_assignable_from(named_zippable_type) - ) - - self.assertIsNone(not_zipped) - - type_test_utils.assert_types_equivalent( - zipped.type_signature, named_zippable_type - ) - - def test_does_not_zip_under_function(self): - result_comp = building_blocks.Struct([ - building_blocks.Reference( - 'x', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - building_blocks.Reference( - 'y', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]) - lam = building_blocks.Lambda(None, None, result_comp) - zippable_function_type = computation_types.FunctionType( - None, - computation_types.FederatedType( - computation_types.StructType([(None, np.int32), (None, np.int32)]), - placements.CLIENTS, - ), - ) - - zipped = building_block_factory.zip_to_match_type( - comp_to_zip=lam, target_type=zippable_function_type - ) - - self.assertIsNone(zipped) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/constructs_correct_computation_clients.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/constructs_correct_computation_clients.expected deleted file mode 100644 index 94b834ddb8..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/constructs_correct_computation_clients.expected +++ /dev/null @@ -1,10 +0,0 @@ -federated_map(< - (let - value_comp_placeholder=x - in (lambda_arg -> < - a=value_comp_placeholder, - lambda_arg[1], - b=lambda_arg[2] - >)), - federated_comp ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/constructs_correct_computation_server.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/constructs_correct_computation_server.expected deleted file mode 100644 index 331591dde7..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/constructs_correct_computation_server.expected +++ /dev/null @@ -1,10 +0,0 @@ -federated_apply(< - (let - value_comp_placeholder=x - in (lambda_arg -> < - a=value_comp_placeholder, - lambda_arg[1], - b=lambda_arg[2] - >)), - federated_comp ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/replaces_single_element.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/replaces_single_element.expected deleted file mode 100644 index 39ce229921..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/replaces_single_element.expected +++ /dev/null @@ -1,6 +0,0 @@ -(let - value_comp_placeholder=x - in (lambda_arg -> < - a=value_comp_placeholder, - b=lambda_arg[1] ->)) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/skips_unnamed_element.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/skips_unnamed_element.expected deleted file mode 100644 index ca0e3225e1..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/skips_unnamed_element.expected +++ /dev/null @@ -1,7 +0,0 @@ -(let - value_comp_placeholder=x - in (lambda_arg -> < - a=value_comp_placeholder, - lambda_arg[1], - b=lambda_arg[2] ->)) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected deleted file mode 100644 index 7289692884..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_different_typed.expected +++ /dev/null @@ -1,12 +0,0 @@ -(let - value=v - in < - federated_apply(< - (arg -> arg[0]), - value - >), - federated_apply(< - (arg -> arg[1]), - value - >) ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_named.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_named.expected deleted file mode 100644 index d8d9e701ac..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_named.expected +++ /dev/null @@ -1,12 +0,0 @@ -(let - value=v - in < - a=federated_apply(< - (arg -> arg[0]), - value - >), - b=federated_apply(< - (arg -> arg[1]), - value - >) ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_unnamed.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_unnamed.expected deleted file mode 100644 index 7289692884..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_apply_with_two_values_unnamed.expected +++ /dev/null @@ -1,12 +0,0 @@ -(let - value=v - in < - federated_apply(< - (arg -> arg[0]), - value - >), - federated_apply(< - (arg -> arg[1]), - value - >) ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_different_typed.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_different_typed.expected deleted file mode 100644 index e8ba4cd327..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_different_typed.expected +++ /dev/null @@ -1,12 +0,0 @@ -(let - value=v - in < - federated_map(< - (arg -> arg[0]), - value - >), - federated_map(< - (arg -> arg[1]), - value - >) ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected deleted file mode 100644 index 5ccbc07f75..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_named.expected +++ /dev/null @@ -1,12 +0,0 @@ -(let - value=v - in < - a=federated_map(< - (arg -> arg[0]), - value - >), - b=federated_map(< - (arg -> arg[1]), - value - >) ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected deleted file mode 100644 index e8ba4cd327..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/tuple_federated_map_with_two_values_unnamed.expected +++ /dev/null @@ -1,12 +0,0 @@ -(let - value=v - in < - federated_map(< - (arg -> arg[0]), - value - >), - federated_map(< - (arg -> arg[1]), - value - >) ->) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_reference.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_reference.expected deleted file mode 100644 index 7898192261..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_reference.expected +++ /dev/null @@ -1 +0,0 @@ -a diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_tuple_named.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_tuple_named.expected deleted file mode 100644 index 504b29be33..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_tuple_named.expected +++ /dev/null @@ -1,12 +0,0 @@ -< - g=< - d=a, - e=b, - f=c - >, - h=< - d=a, - e=b, - f=c - > -> diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_tuple_unnamed.expected b/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_tuple_unnamed.expected deleted file mode 100644 index 9bfe231e55..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_block_factory_test_goldens/zips_tuple_unnamed.expected +++ /dev/null @@ -1,12 +0,0 @@ -< - < - a, - b, - c - >, - < - a, - b, - c - > -> diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py b/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py index c96e1e4962..da6dae9bd9 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py +++ b/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py @@ -15,19 +15,15 @@ from typing import Union +import federated_language import numpy as np from google.protobuf import any_pb2 -from tensorflow_federated.python.core.impl.compiler import array -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def create_any_proto_from_array(value: np.ndarray): """Creates an `Any` proto for the given `np.array` value.""" - test_proto = array.to_proto(value) + test_proto = federated_language.array_to_proto(value) any_proto = any_pb2.Any() any_proto.Pack(test_proto) return any_proto @@ -51,10 +47,10 @@ def create_chained_calls(functions, arg): Args: functions: A Python list of functional computations. - arg: A `building_blocks.ComputationBuildingBlock`. + arg: A `federated_language.framework.ComputationBuildingBlock`. Returns: - A `building_blocks.Call`. + A `federated_language.framework.Call`. """ for fn in functions: if not fn.parameter_type.is_assignable_from(arg.type_signature): @@ -64,7 +60,7 @@ def create_chained_calls(functions, arg): str(fn.parameter_type), str(arg.type_signature) ) ) - call = building_blocks.Call(fn, arg) + call = federated_language.framework.Call(fn, arg) arg = call return call @@ -83,8 +79,10 @@ def create_whimsy_block( variable_name: The name of the variable. variable_type: The type of the variable. """ - ref = building_blocks.Literal(1, computation_types.TensorType(variable_type)) - return building_blocks.Block([(variable_name, ref)], comp) + ref = federated_language.framework.Literal( + 1, federated_language.TensorType(variable_type) + ) + return federated_language.framework.Block([(variable_name, ref)], comp) def create_whimsy_called_intrinsic(parameter_name, parameter_type=np.int32): @@ -98,12 +96,14 @@ def create_whimsy_called_intrinsic(parameter_name, parameter_type=np.int32): parameter_name: The name of the parameter. parameter_type: The type of the parameter. """ - intrinsic_type = computation_types.FunctionType( + intrinsic_type = federated_language.FunctionType( parameter_type, parameter_type ) - intrinsic = building_blocks.Intrinsic('intrinsic', intrinsic_type) - ref = building_blocks.Reference(parameter_name, parameter_type) - return building_blocks.Call(intrinsic, ref) + intrinsic = federated_language.framework.Intrinsic( + 'intrinsic', intrinsic_type + ) + ref = federated_language.framework.Reference(parameter_name, parameter_type) + return federated_language.framework.Call(intrinsic, ref) def create_whimsy_called_federated_aggregate( @@ -128,25 +128,28 @@ def create_whimsy_called_federated_aggregate( report_parameter_name: The name of the report parameter. value_type: The TFF type of the value to be aggregated, placed at CLIENTS. """ - tensor_type = computation_types.TensorType(value_type) - value = building_block_factory.create_federated_value( - building_blocks.Literal(1, tensor_type), placements.CLIENTS + tensor_type = federated_language.TensorType(value_type) + value = federated_language.framework.create_federated_value( + federated_language.framework.Literal(1, tensor_type), + federated_language.CLIENTS, ) - literal_block = building_blocks.Literal(1, tensor_type) + literal_block = federated_language.framework.Literal(1, tensor_type) zero = literal_block - accumulate_type = computation_types.StructType((value_type, value_type)) + accumulate_type = federated_language.StructType((value_type, value_type)) accumulate_result = literal_block - accumulate = building_blocks.Lambda( + accumulate = federated_language.framework.Lambda( accumulate_parameter_name, accumulate_type, accumulate_result ) - merge_type = computation_types.StructType((value_type, value_type)) + merge_type = federated_language.StructType((value_type, value_type)) merge_result = literal_block - merge = building_blocks.Lambda(merge_parameter_name, merge_type, merge_result) + merge = federated_language.framework.Lambda( + merge_parameter_name, merge_type, merge_result + ) report_result = literal_block - report = building_blocks.Lambda( + report = federated_language.framework.Lambda( report_parameter_name, value_type, report_result ) - return building_block_factory.create_federated_aggregate( + return federated_language.framework.create_federated_aggregate( value, zero, accumulate, merge, report ) @@ -170,11 +173,12 @@ def create_whimsy_called_federated_apply( """ value = parameter_type(1) fn = create_identity_function(parameter_name, parameter_type) - arg_type = computation_types.TensorType(parameter_type) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(value, arg_type), placement=placements.SERVER + arg_type = federated_language.TensorType(parameter_type) + arg = federated_language.framework.create_federated_value( + federated_language.framework.Literal(value, arg_type), + placement=federated_language.SERVER, ) - return building_block_factory.create_federated_apply(fn, arg) + return federated_language.framework.create_federated_apply(fn, arg) def create_whimsy_called_federated_broadcast( @@ -190,11 +194,12 @@ def create_whimsy_called_federated_broadcast( value_type: The type of the value. """ value = value_type(1) - tensor_type = computation_types.TensorType(value_type) - value = building_block_factory.create_federated_value( - building_blocks.Literal(value, tensor_type), placement=placements.SERVER + tensor_type = federated_language.TensorType(value_type) + value = federated_language.framework.create_federated_value( + federated_language.framework.Literal(value, tensor_type), + placement=federated_language.SERVER, ) - return building_block_factory.create_federated_broadcast(value) + return federated_language.framework.create_federated_broadcast(value) def create_whimsy_called_federated_map( @@ -216,18 +221,19 @@ def create_whimsy_called_federated_map( """ value = parameter_type(1) fn = create_identity_function(parameter_name, parameter_type) - arg_type = computation_types.TensorType(parameter_type) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(value, arg_type), placement=placements.CLIENTS + arg_type = federated_language.TensorType(parameter_type) + arg = federated_language.framework.create_federated_value( + federated_language.framework.Literal(value, arg_type), + placement=federated_language.CLIENTS, ) # TODO: b/338284242 - Replace this with a `Data` block once the compiler tests # do not use string equality. # pylint: disable=protected-access - arg._type_signature = computation_types.FederatedType( - arg_type, placements.CLIENTS, all_equal=False + arg._type_signature = federated_language.FederatedType( + arg_type, federated_language.CLIENTS, all_equal=False ) # pylint: enable=protected-access - return building_block_factory.create_federated_map(fn, arg) + return federated_language.framework.create_federated_map(fn, arg) def create_whimsy_called_federated_map_all_equal( @@ -249,11 +255,12 @@ def create_whimsy_called_federated_map_all_equal( """ value = parameter_type(1) fn = create_identity_function(parameter_name, parameter_type) - arg_type = computation_types.TensorType(parameter_type) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(value, arg_type), placement=placements.CLIENTS + arg_type = federated_language.TensorType(parameter_type) + arg = federated_language.framework.create_federated_value( + federated_language.framework.Literal(value, arg_type), + placement=federated_language.CLIENTS, ) - return building_block_factory.create_federated_map_all_equal(fn, arg) + return federated_language.framework.create_federated_map_all_equal(fn, arg) def create_whimsy_called_federated_mean( @@ -262,21 +269,21 @@ def create_whimsy_called_federated_mean( ): """Returns a called federated mean.""" value = value_type(1) - value_type = computation_types.TensorType(value_type) - values = building_block_factory.create_federated_value( - building_blocks.Literal(value, value_type), - placement=placements.CLIENTS, + value_type = federated_language.TensorType(value_type) + values = federated_language.framework.create_federated_value( + federated_language.framework.Literal(value, value_type), + placement=federated_language.CLIENTS, ) if weights_type is not None: weights_value = weights_type(1) - weights_type = computation_types.TensorType(weights_type) - weights = building_block_factory.create_federated_value( - building_blocks.Literal(weights_value, weights_type), - placement=placements.CLIENTS, + weights_type = federated_language.TensorType(weights_type) + weights = federated_language.framework.create_federated_value( + federated_language.framework.Literal(weights_value, weights_type), + placement=federated_language.CLIENTS, ) else: weights = None - return building_block_factory.create_federated_mean(values, weights) + return federated_language.framework.create_federated_mean(values, weights) def create_whimsy_called_federated_secure_sum_bitwidth( @@ -292,13 +299,13 @@ def create_whimsy_called_federated_secure_sum_bitwidth( value_type: The type of the value. """ lit_value = value_type(1) - tensor_type = computation_types.TensorType(value_type) - value = building_block_factory.create_federated_value( - building_blocks.Literal(lit_value, tensor_type), - placement=placements.CLIENTS, + tensor_type = federated_language.TensorType(value_type) + value = federated_language.framework.create_federated_value( + federated_language.framework.Literal(lit_value, tensor_type), + placement=federated_language.CLIENTS, ) - bitwidth = building_blocks.Literal(lit_value, tensor_type) - return building_block_factory.create_federated_secure_sum_bitwidth( + bitwidth = federated_language.framework.Literal(lit_value, tensor_type) + return federated_language.framework.create_federated_secure_sum_bitwidth( value, bitwidth ) @@ -316,11 +323,12 @@ def create_whimsy_called_federated_sum( value_type: The type of the value. """ value = value_type(1) - tensor_type = computation_types.TensorType(value_type) - value = building_block_factory.create_federated_value( - building_blocks.Literal(value, tensor_type), placement=placements.CLIENTS + tensor_type = federated_language.TensorType(value_type) + value = federated_language.framework.create_federated_value( + federated_language.framework.Literal(value, tensor_type), + placement=federated_language.CLIENTS, ) - return building_block_factory.create_federated_sum(value) + return federated_language.framework.create_federated_sum(value) def create_whimsy_called_sequence_map( @@ -342,20 +350,20 @@ def create_whimsy_called_sequence_map( any_proto: The any proto to use for the data block. """ fn = create_identity_function(parameter_name, parameter_type) - arg_type = computation_types.SequenceType(parameter_type) - arg = building_blocks.Data(any_proto, arg_type) - return building_block_factory.create_sequence_map(fn, arg) + arg_type = federated_language.SequenceType(parameter_type) + arg = federated_language.framework.Data(any_proto, arg_type) + return federated_language.framework.create_sequence_map(fn, arg) def create_whimsy_called_federated_value( - placement: placements.PlacementLiteral, + placement: federated_language.framework.PlacementLiteral, value_type: type[np.generic] = np.int32, ): value = value_type(1) - value = building_blocks.Literal( - value, computation_types.TensorType(value_type) + value = federated_language.framework.Literal( + value, federated_language.TensorType(value_type) ) - return building_block_factory.create_federated_value(value, placement) + return federated_language.framework.create_federated_value(value, placement) def create_identity_block(variable_name, comp): @@ -369,8 +377,10 @@ def create_identity_block(variable_name, comp): variable_name: The name of the variable. comp: The computation to use as the variable. """ - ref = building_blocks.Reference(variable_name, comp.type_signature) - return building_blocks.Block([(variable_name, comp)], ref) + ref = federated_language.framework.Reference( + variable_name, comp.type_signature + ) + return federated_language.framework.Block([(variable_name, comp)], ref) def create_identity_block_with_whimsy_ref( @@ -387,8 +397,8 @@ def create_identity_block_with_whimsy_ref( variable_type: The type of the variable. """ value = variable_type(1) - literal = building_blocks.Literal( - value, computation_types.TensorType(variable_type) + literal = federated_language.framework.Literal( + value, federated_language.TensorType(variable_type) ) return create_identity_block(variable_name, literal) @@ -404,8 +414,8 @@ def create_identity_function(parameter_name, parameter_type=np.int32): parameter_name: The name of the parameter. parameter_type: The type of the parameter. """ - ref = building_blocks.Reference(parameter_name, parameter_type) - return building_blocks.Lambda(ref.name, ref.type_signature, ref) + ref = federated_language.framework.Reference(parameter_name, parameter_type) + return federated_language.framework.Lambda(ref.name, ref.type_signature, ref) def create_lambda_to_whimsy_called_intrinsic( @@ -426,7 +436,9 @@ def create_lambda_to_whimsy_called_intrinsic( call = create_whimsy_called_intrinsic( parameter_name=parameter_name, parameter_type=parameter_type ) - return building_blocks.Lambda(parameter_name, parameter_type, call) + return federated_language.framework.Lambda( + parameter_name, parameter_type, call + ) def create_nested_syntax_tree(): @@ -441,7 +453,7 @@ def create_nested_syntax_tree(): parameter*, so that if we were actually executing this call the argument will be thrown away. - All leaf nodes are instances of `building_blocks.Lit`. + All leaf nodes are instances of `federated_language.framework.Lit`. Call / \ @@ -497,38 +509,46 @@ def create_nested_syntax_tree(): [arg, y, z, t, u, v, x, w] Returns: - An instance of `building_blocks.ComputationBuildingBlock` + An instance of `federated_language.framework.ComputationBuildingBlock` satisfying the description above. """ - tensor_type = computation_types.TensorType(np.int32) - lit_c = building_blocks.Literal(3, tensor_type) - lit_d = building_blocks.Literal(4, tensor_type) - left_most_leaf = building_blocks.Block([('t', lit_c)], lit_d) - - lit_e = building_blocks.Literal(5, tensor_type) - lit_f = building_blocks.Literal(6, tensor_type) - center_leaf = building_blocks.Block([('u', lit_e)], lit_f) - inner_tuple = building_blocks.Struct([left_most_leaf, center_leaf]) - - selected = building_blocks.Selection(inner_tuple, index=0) - lit_g = building_blocks.Literal(7, tensor_type) - middle_block = building_blocks.Block([('v', selected)], lit_g) - - lit_i = building_blocks.Literal(8, tensor_type) - lit_j = building_blocks.Literal(9, tensor_type) - right_most_endpoint = building_blocks.Block([('w', lit_i)], lit_j) - - lit_h = building_blocks.Literal(10, tensor_type) - right_child = building_blocks.Block([('x', lit_h)], right_most_endpoint) - - result = building_blocks.Struct([middle_block, right_child]) - lit_a = building_blocks.Literal(1, tensor_type) - lit_b = building_blocks.Literal(2, tensor_type) - whimsy_outer_block = building_blocks.Block( + tensor_type = federated_language.TensorType(np.int32) + lit_c = federated_language.framework.Literal(3, tensor_type) + lit_d = federated_language.framework.Literal(4, tensor_type) + left_most_leaf = federated_language.framework.Block([('t', lit_c)], lit_d) + + lit_e = federated_language.framework.Literal(5, tensor_type) + lit_f = federated_language.framework.Literal(6, tensor_type) + center_leaf = federated_language.framework.Block([('u', lit_e)], lit_f) + inner_tuple = federated_language.framework.Struct( + [left_most_leaf, center_leaf] + ) + + selected = federated_language.framework.Selection(inner_tuple, index=0) + lit_g = federated_language.framework.Literal(7, tensor_type) + middle_block = federated_language.framework.Block([('v', selected)], lit_g) + + lit_i = federated_language.framework.Literal(8, tensor_type) + lit_j = federated_language.framework.Literal(9, tensor_type) + right_most_endpoint = federated_language.framework.Block( + [('w', lit_i)], lit_j + ) + + lit_h = federated_language.framework.Literal(10, tensor_type) + right_child = federated_language.framework.Block( + [('x', lit_h)], right_most_endpoint + ) + + result = federated_language.framework.Struct([middle_block, right_child]) + lit_a = federated_language.framework.Literal(1, tensor_type) + lit_b = federated_language.framework.Literal(2, tensor_type) + whimsy_outer_block = federated_language.framework.Block( [('y', lit_a), ('z', lit_b)], result ) - whimsy_lambda = building_blocks.Lambda('arg', tensor_type, whimsy_outer_block) - whimsy_arg = building_blocks.Literal(11, tensor_type) - called_lambda = building_blocks.Call(whimsy_lambda, whimsy_arg) + whimsy_lambda = federated_language.framework.Lambda( + 'arg', tensor_type, whimsy_outer_block + ) + whimsy_arg = federated_language.framework.Literal(11, tensor_type) + called_lambda = federated_language.framework.Call(whimsy_lambda, whimsy_arg) return called_lambda diff --git a/tensorflow_federated/python/core/impl/compiler/building_blocks.py b/tensorflow_federated/python/core/impl/compiler/building_blocks.py deleted file mode 100644 index f55f1431f8..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_blocks.py +++ /dev/null @@ -1,1831 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of classes representing computations in a deserialized form.""" - -import abc -from collections.abc import Iterable, Iterator -import enum -import typing -from typing import Optional, Union -import zlib - -import numpy as np - -from google.protobuf import any_pb2 -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import array -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import typed_object - - -def _check_computation_oneof( - computation_proto: pb.Computation, - expected_oneof: str, -): - """Checks that `computation_proto` is a oneof of the expected variant.""" - computation_oneof = computation_proto.WhichOneof('computation') - if computation_oneof != expected_oneof: - raise ValueError( - f'Expected the computation to be a {expected_oneof}, found' - f' {computation_oneof}.' - ) - - -class UnexpectedBlockError(TypeError): - - def __init__( - self, - expected: type['ComputationBuildingBlock'], - actual: 'ComputationBuildingBlock', - ): - message = f'Expected block of kind {expected}, found block {actual}' - super().__init__(message) - self.actual = actual - self.expected = expected - - -class ComputationBuildingBlock(typed_object.TypedObject, metaclass=abc.ABCMeta): - """The abstract base class for abstractions in the TFF's internal language. - - Instances of this class correspond roughly one-to-one to the abstractions - defined in the `Computation` message in TFF's `computation.proto`, and are - intended primarily for the ease of manipulating the abstract syntax trees - (AST) of federated computations as they are transformed by TFF's compiler - pipeline to mold into the needs of a particular execution backend. The only - abstraction that does not have a dedicated Python equivalent is a section - of TensorFlow code (it's represented by `tff.framework.CompiledComputation`). - """ - - @classmethod - def from_proto( - cls, computation_proto: pb.Computation - ) -> 'ComputationBuildingBlock': - """Returns an instance of a derived class based on 'computation_proto'. - - Args: - computation_proto: An instance of pb.Computation. - - Returns: - An instance of a class that implements 'ComputationBuildingBlock' and - that contains the deserialized logic from in 'computation_proto'. - - Raises: - NotImplementedError: if computation_proto contains a kind of computation - for which deserialization has not been implemented yet. - ValueError: if deserialization failed due to the argument being invalid. - """ - py_typecheck.check_type(computation_proto, pb.Computation) - computation_oneof = computation_proto.WhichOneof('computation') - deserializer = _deserializer_dict.get(computation_oneof) - if deserializer is not None: - deserialized = deserializer(computation_proto) - type_spec = type_serialization.deserialize_type(computation_proto.type) - if not deserialized.type_signature.is_equivalent_to(type_spec): - raise ValueError( - 'The type {} derived from the computation structure does not ' - 'match the type {} declared in its signature'.format( - deserialized.type_signature, type_spec - ) - ) - return deserialized - else: - raise NotImplementedError( - 'Deserialization for computations of type {} has not been ' - 'implemented yet.'.format(computation_oneof) - ) - return deserializer(computation_proto) - - def __init__(self, type_spec): - """Constructs a computation building block with the given TFF type. - - Args: - type_spec: An instance of types.Type, or something convertible to it via - types.to_type(). - """ - type_signature = computation_types.to_type(type_spec) - self._type_signature = type_signature - self._cached_hash = None - self._cached_proto = None - - @property - def type_signature(self) -> computation_types.Type: - return self._type_signature - - @abc.abstractmethod - def children(self) -> Iterator['ComputationBuildingBlock']: - """Returns an iterator yielding immediate child building blocks.""" - raise NotImplementedError - - def compact_representation(self): - """Returns the compact string representation of this building block.""" - return _string_representation(self, formatted=False) - - def formatted_representation(self): - """Returns the formatted string representation of this building block.""" - return _string_representation(self, formatted=True) - - def structural_representation(self): - """Returns the structural string representation of this building block.""" - return _structural_representation(self) - - @property - def proto(self): - """Returns a serialized form of this object as a pb.Computation instance.""" - if self._cached_proto is None: - self._cached_proto = self._proto() - return self._cached_proto - - @abc.abstractmethod - def _proto(self): - """Uncached, internal version of `proto`.""" - raise NotImplementedError - - # TODO: b/113112885 - Add memoization after identifying a suitable externally - # available standard library that works in Python 2/3. - - @abc.abstractmethod - def __repr__(self): - """Returns a full-form representation of this computation building block.""" - raise NotImplementedError - - def __str__(self): - """Returns a concise representation of this computation building block.""" - return self.compact_representation() - - -class Reference(ComputationBuildingBlock): - """A reference to a name defined earlier in TFF's internal language. - - Names are defined by lambda expressions (which have formal named parameters), - and block structures (which can have one or more locals). The reference - construct is used to refer to those parameters or locals by a string name. - The usual hiding rules apply. A reference binds to the closest definition of - the given name in the most deeply nested surrounding lambda or block. - - A concise notation for a reference to name `foo` is `foo`. For example, in - a lambda expression `(x -> f(x))` there are two references, one to `x` that - is defined as the formal parameter of the lambda epxression, and one to `f` - that must have been defined somewhere in the surrounding context. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Reference': - _check_computation_oneof(computation_proto, 'reference') - return cls( - str(computation_proto.reference.name), - type_serialization.deserialize_type(computation_proto.type), - ) - - def __init__(self, name: str, type_spec: object, context=None): - """Creates a reference to 'name' of type 'type_spec' in context 'context'. - - Args: - name: The name of the referenced entity. - type_spec: The type spec of the referenced entity. - context: The optional context in which the referenced entity is defined. - This class does not prescribe what Python type the 'context' needs to be - and merely exposes it as a property (see below). The only requirement is - that the context implements str() and repr(). - - Raises: - TypeError: if the arguments are of the wrong types. - """ - py_typecheck.check_type(name, str) - super().__init__(type_spec) - self._name = name - self._context = context - - def _proto(self): - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - reference=pb.Reference(name=self._name), - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - del self - return iter(()) - - @property - def name(self) -> str: - return self._name - - @property - def context(self): - return self._context - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Reference): - return NotImplemented - # Important: References are only equal to each other if they are the same - # object because two references with the same `name` are different if they - # are in different locations within the same scope, in different scopes, or - # in different contexts. - return False - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((self._name, self._type_signature)) - return self._cached_hash - - def __repr__(self): - return "Reference('{}', {!r}{})".format( - self._name, - self.type_signature, - ', {!r}'.format(self._context) if self._context is not None else '', - ) - - -class Selection(ComputationBuildingBlock): - """A selection by name or index from a struct-typed value in TFF's language. - - The concise syntax for selections is `foo.bar` (selecting a named `bar` from - the value of expression `foo`), and `foo[n]` (selecting element at index `n` - from the value of `foo`). - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Selection': - _check_computation_oneof(computation_proto, 'selection') - selection = ComputationBuildingBlock.from_proto( - computation_proto.selection.source - ) - return cls(selection, index=computation_proto.selection.index) - - def __init__( - self, - source: ComputationBuildingBlock, - name: Optional[str] = None, - index: Optional[int] = None, - ): - """A selection from 'source' by a string or numeric 'name_or_index'. - - Exactly one of 'name' or 'index' must be specified (not None). - - Args: - source: The source value to select from (an instance of - ComputationBuildingBlock). - name: A string name of the element to be selected. - index: A numeric index of the element to be selected. - - Raises: - TypeError: if arguments are of the wrong types. - ValueError: if the name is empty or index is negative, or the name/index - is not compatible with the type signature of the source, or neither or - both are defined (not None). - """ - py_typecheck.check_type(source, ComputationBuildingBlock) - source_type = source.type_signature - # TODO: b/224484886 - Downcasting to all handled types. - source_type = typing.cast(Union[computation_types.StructType], source_type) - if not isinstance(source_type, computation_types.StructType): - raise TypeError( - 'Expected the source of selection to be a TFF struct, ' - 'instead found it to be of type {}.'.format(source_type) - ) - if name is not None and index is not None: - raise ValueError( - 'Cannot simultaneously specify a name and an index, choose one.' - ) - if name is not None: - py_typecheck.check_type(name, str) - if not name: - raise ValueError('The name of the selected element cannot be empty.') - # Normalize, in case we are dealing with a Unicode type or some such. - name = str(name) - if not structure.has_field(source_type, name): - raise ValueError( - f'Error selecting named field `{name}` from type `{source_type}`, ' - f'whose only named fields are {structure.name_list(source_type)}.' - ) - type_signature = source_type[name] - elif index is not None: - py_typecheck.check_type(index, int) - length = len(source_type) - if index < 0 or index >= length: - raise ValueError( - f'The index `{index}` does not fit into the valid range in the ' - f'struct type: 0..{length}' - ) - type_signature = source_type[index] - else: - raise ValueError( - 'Must define either a name or index, and neither was specified.' - ) - super().__init__(type_signature) - self._source = source - self._name = name - self._index = index - - def _proto(self): - selection = pb.Selection(source=self._source.proto, index=self.as_index()) - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - selection=selection, - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - yield self._source - - @property - def source(self) -> ComputationBuildingBlock: - return self._source - - @property - def name(self) -> Optional[str]: - return self._name - - @property - def index(self) -> Optional[int]: - return self._index - - def as_index(self) -> int: - if self._index is not None: - return self._index - else: - field_to_index = structure.name_to_index_map(self.source.type_signature) # pytype: disable=wrong-arg-types - return field_to_index[self._name] - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Selection): - return NotImplemented - return ( - self._source, - self._name, - self._index, - ) == ( - other._source, - other._name, - other._index, - ) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((self._source, self._name, self._index)) - return self._cached_hash - - def __repr__(self): - if self._name is not None: - return "Selection({!r}, name='{}')".format(self._source, self._name) - else: - return 'Selection({!r}, index={})'.format(self._source, self._index) - - -class Struct(ComputationBuildingBlock, structure.Struct): - """A struct with named or unnamed elements in TFF's internal language. - - The concise notation for structs is `` - for structs with named elements, `` for structs with - unnamed elements, or a mixture of these for structs with some named and some - unnamed elements, where `name_k` are the names, and `value_k` are the value - expressions. - - For example, a lambda expression that applies `fn` to elements of 2-structs - pointwise could be represented as `(arg -> )`. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Struct': - _check_computation_oneof(computation_proto, 'struct') - - def _element( - proto: pb.Struct.Element, - ) -> tuple[Optional[str], ComputationBuildingBlock]: - if proto.name: - name = str(proto.name) - else: - name = None - element = ComputationBuildingBlock.from_proto(proto.value) - return (name, element) - - elements = [_element(x) for x in computation_proto.struct.element] - return cls(elements) - - def __init__(self, elements, container_type=None): - """Constructs a struct from the given list of elements. - - Args: - elements: The elements of the struct, supplied as a list of (name, value) - pairs, where 'name' can be None in case the corresponding element is not - named and only accessible via an index (see also `structure.Struct`). - container_type: An optional Python container type to associate with the - struct. - - Raises: - TypeError: if arguments are of the wrong types. - """ - - # Not using super() here and below, as the two base classes have different - # signatures of their constructors, and the struct implementation - # of selection interfaces should override that in the generic class 'Value' - # to favor simplified expressions where simplification is possible. - def _map_element(e): - """Returns a named or unnamed element.""" - if isinstance(e, ComputationBuildingBlock): - return (None, e) - elif py_typecheck.is_name_value_pair( - e, value_type=ComputationBuildingBlock - ): - if e[0] is not None and not e[0]: - raise ValueError('Unexpected struct element with empty string name.') - return (e[0], e[1]) - else: - raise TypeError('Unexpected struct element: {}.'.format(e)) - - elements = [_map_element(e) for e in elements] - element_pairs = [ - ((e[0], e[1].type_signature) if e[0] else e[1].type_signature) - for e in elements - ] - - if container_type is None: - type_signature = computation_types.StructType(element_pairs) - else: - type_signature = computation_types.StructWithPythonType( - element_pairs, container_type - ) - ComputationBuildingBlock.__init__(self, type_signature) - structure.Struct.__init__(self, elements) - self._type_signature = type_signature - - @property - def type_signature(self) -> computation_types.StructType: - return self._type_signature - - def _proto(self): - elements = [] - for k, v in structure.iter_elements(self): - if k is not None: - element = pb.Struct.Element(name=k, value=v.proto) - else: - element = pb.Struct.Element(value=v.proto) - elements.append(element) - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - struct=pb.Struct(element=elements), - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - return (element for _, element in structure.iter_elements(self)) - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Struct): - return NotImplemented - if self._type_signature != other._type_signature: - return False - return structure.Struct.__eq__(self, other) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash(( - structure.Struct.__hash__(self), - self._type_signature, - )) - return self._cached_hash - - def __repr__(self): - def _element_repr(element): - name, value = element - name_repr = "'{}'".format(name) if name is not None else 'None' - return '({}, {!r})'.format(name_repr, value) - - return 'Struct([{}])'.format( - ', '.join(_element_repr(e) for e in structure.iter_elements(self)) - ) - - -class Call(ComputationBuildingBlock): - """A representation of a function invocation in TFF's internal language. - - The call construct takes an argument struct with two elements, the first being - the function to invoke (represented as a computation with a functional result - type), and the second being the argument to feed to that function. Typically, - the function is either a TFF instrinsic, or a lambda expression. - - The concise notation for calls is `foo(bar)`, where `foo` is the function, - and `bar` is the argument. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Call': - _check_computation_oneof(computation_proto, 'call') - fn = ComputationBuildingBlock.from_proto(computation_proto.call.function) - arg_proto = computation_proto.call.argument - if arg_proto.WhichOneof('computation') is not None: - arg = ComputationBuildingBlock.from_proto(arg_proto) - else: - arg = None - return cls(fn, arg) - - def __init__( - self, - fn: ComputationBuildingBlock, - arg: Optional[ComputationBuildingBlock] = None, - ): - """Creates a call to 'fn' with argument 'arg'. - - Args: - fn: A value of a functional type that represents the function to invoke. - arg: The optional argument, present iff 'fn' expects one, of a type that - matches the type of 'fn'. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - py_typecheck.check_type(fn, ComputationBuildingBlock) - if arg is not None: - py_typecheck.check_type(arg, ComputationBuildingBlock) - function_type = fn.type_signature - # TODO: b/224484886 - Downcasting to all handled types. - function_type = typing.cast( - Union[computation_types.FunctionType], function_type - ) - if not isinstance(function_type, computation_types.FunctionType): - raise TypeError( - f'Expected `fn` to have a `tff.FunctionType`, found {function_type}.' - ) - parameter_type = function_type.parameter - if parameter_type is not None: - if arg is None: - raise TypeError( - f'Expected `arg` to be of type {parameter_type}, found None.' - ) - elif not parameter_type.is_assignable_from(arg.type_signature): - raise TypeError( - f'Expected `arg` to be of type {parameter_type}, found an' - f' incompatible type {arg.type_signature}.' - ) - else: - if arg is not None: - raise TypeError(f'Expected `arg` to be None, found {arg}.') - super().__init__(function_type.result) - self._function = fn - self._argument = arg - - def _proto(self): - if self._argument is not None: - call = pb.Call( - function=self._function.proto, argument=self._argument.proto - ) - else: - call = pb.Call(function=self._function.proto) - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), call=call - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - yield self._function - if self._argument is not None: - yield self._argument - - @property - def function(self) -> ComputationBuildingBlock: - return self._function - - @property - def argument(self) -> Optional[ComputationBuildingBlock]: - return self._argument - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Call): - return NotImplemented - return ( - self._function, - self._argument, - ) == ( - other._function, - other._argument, - ) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((self._function, self._argument)) - return self._cached_hash - - def __repr__(self): - if self._argument is not None: - return 'Call({!r}, {!r})'.format(self._function, self._argument) - else: - return 'Call({!r})'.format(self._function) - - -class Lambda(ComputationBuildingBlock): - """A representation of a lambda expression in TFF's internal language. - - A lambda expression consists of a string formal parameter name, and a result - expression that can contain references by name to that formal parameter. A - concise notation for lambdas is `(foo -> bar)`, where `foo` is the name of - the formal parameter, and `bar` is the result expression. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Lambda': - _check_computation_oneof(computation_proto, 'lambda') - fn: pb.Lambda = getattr(computation_proto, 'lambda') - if computation_proto.type.function.HasField('parameter'): - parameter_type = type_serialization.deserialize_type( - computation_proto.type.function.parameter - ) - else: - parameter_type = None - result = ComputationBuildingBlock.from_proto(fn.result) - return cls(fn.parameter_name, parameter_type, result) - - def __init__( - self, - parameter_name: Optional[str], - parameter_type: Optional[object], - result: ComputationBuildingBlock, - ): - """Creates a lambda expression. - - Args: - parameter_name: The (string) name of the parameter accepted by the lambda. - This name can be used by Reference() instances in the body of the lambda - to refer to the parameter. Note that an empty parameter name shall be - treated as equivalent to no parameter. - parameter_type: The type of the parameter, an instance of types.Type or - something convertible to it by types.to_type(). - result: The resulting value produced by the expression that forms the body - of the lambda. Must be an instance of ComputationBuildingBlock. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - if not parameter_name: - parameter_name = None - if (parameter_name is None) != (parameter_type is None): - raise TypeError( - 'A lambda expression must have either a valid parameter name and ' - 'type or both parameter name and type must be `None`. ' - '`parameter_name` was {} but `parameter_type` was {}.'.format( - parameter_name, parameter_type - ) - ) - if parameter_name is not None: - py_typecheck.check_type(parameter_name, str) - parameter_type = computation_types.to_type(parameter_type) - py_typecheck.check_type(result, ComputationBuildingBlock) - type_signature = computation_types.FunctionType( - parameter_type, result.type_signature - ) - super().__init__(type_signature) - self._parameter_name = parameter_name - self._parameter_type = parameter_type - self._result = result - self._type_signature = type_signature - - @property - def type_signature(self) -> computation_types.FunctionType: - return self._type_signature - - def _proto(self) -> pb.Computation: - type_signature = type_serialization.serialize_type(self.type_signature) - fn = pb.Lambda( - parameter_name=self._parameter_name, result=self._result.proto - ) - # We are unpacking the lambda argument here because `lambda` is a reserved - # keyword in Python, but it is also the name of the parameter for a - # `pb.Computation`. - # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts - return pb.Computation(type=type_signature, **{'lambda': fn}) # pytype: disable=wrong-keyword-args - - def children(self) -> Iterator[ComputationBuildingBlock]: - yield self._result - - @property - def parameter_name(self) -> Optional[str]: - return self._parameter_name - - @property - def parameter_type(self) -> Optional[computation_types.Type]: - return self._parameter_type - - @property - def result(self) -> ComputationBuildingBlock: - return self._result - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Lambda): - return NotImplemented - return ( - self._parameter_name, - self._parameter_type, - self._result, - ) == ( - other._parameter_name, - other._parameter_type, - other._result, - ) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash(( - self._parameter_name, - self._parameter_type, - self._result, - )) - return self._cached_hash - - def __repr__(self) -> str: - return "Lambda('{}', {!r}, {!r})".format( - self._parameter_name, self._parameter_type, self._result - ) - - -class Block(ComputationBuildingBlock): - """A representation of a block of code in TFF's internal language. - - A block is a syntactic structure that consists of a sequence of local name - bindings followed by a result. The bindings are interpreted sequentially, - with bindings later in the sequence in the scope of those listed earlier, - and the result in the scope of the entire sequence. The usual hiding rules - apply. - - An informal concise notation for blocks is the following, with `name_k` - representing the names defined locally for the block, `value_k` the values - associated with them, and `result` being the expression that reprsents the - value of the block construct. - - ``` - let name_1=value_1, name_2=value_2, ..., name_n=value_n in result - ``` - - Blocks are technically a redundant abstraction, as they can be equally well - represented by lambda expressions. A block of the form `let x=y in z` is - roughly equivalent to `(x -> z)(y)`. Although redundant, blocks have a use - as a way to reduce TFF computation ASTs to a simpler, less nested and more - readable form, and are helpful in AST transformations as a mechanism that - prevents possible naming conflicts. - - An example use of a block expression to flatten a nested structure below: - - ``` - z = federated_sum(federated_map(x, federated_broadcast(y))) - ``` - - An equivalent form in a more sequential notation using a block expression: - ``` - let - v1 = federated_broadcast(y), - v2 = federated_map(x, v1) - in - federated_sum(v2) - ``` - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Block': - _check_computation_oneof(computation_proto, 'block') - return cls( - [ - (str(loc.name), ComputationBuildingBlock.from_proto(loc.value)) - for loc in computation_proto.block.local - ], - ComputationBuildingBlock.from_proto(computation_proto.block.result), - ) - - def __init__( - self, - local_symbols: Iterable[tuple[str, ComputationBuildingBlock]], - result: ComputationBuildingBlock, - ): - """Creates a block of TFF code. - - Args: - local_symbols: The list of one or more local declarations, each of which - is a 2-tuple (name, value), with 'name' being the string name of a local - symbol being defined, and 'value' being the instance of - ComputationBuildingBlock, the output of which will be locally bound to - that name. - result: An instance of ComputationBuildingBlock that computes the result. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - updated_locals = [] - for index, element in enumerate(local_symbols): - if ( - not isinstance(element, tuple) - or (len(element) != 2) - or not isinstance(element[0], str) - ): - raise TypeError( - 'Expected the locals to be a list of 2-element structs with string ' - 'name as their first element, but this is not the case for the ' - 'local at position {} in the sequence: {}.'.format(index, element) - ) - name = element[0] - value = element[1] - py_typecheck.check_type(value, ComputationBuildingBlock) - updated_locals.append((name, value)) - py_typecheck.check_type(result, ComputationBuildingBlock) - super().__init__(result.type_signature) - self._locals = updated_locals - self._result = result - - def _proto(self) -> pb.Computation: - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - block=pb.Block( - **{ - 'local': [ - pb.Block.Local(name=k, value=v.proto) - for k, v in self._locals - ], - 'result': self._result.proto, - } - ), - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - for _, value in self._locals: - yield value - yield self._result - - @property - def locals(self) -> list[tuple[str, ComputationBuildingBlock]]: - return list(self._locals) - - @property - def result(self) -> ComputationBuildingBlock: - return self._result - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Block): - return NotImplemented - return (self._locals, self._result) == (other._locals, other._result) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((tuple(self._locals), self._result)) - return self._cached_hash - - def __repr__(self) -> str: - return 'Block([{}], {!r})'.format( - ', '.join("('{}', {!r})".format(k, v) for k, v in self._locals), - self._result, - ) - - -class Intrinsic(ComputationBuildingBlock): - """A representation of an intrinsic in TFF's internal language. - - An instrinsic is a symbol known to the TFF's compiler pipeline, represented - as a known URI. It generally appears in expressions with a concrete type, - although all intrinsic are defined with template types. This class does not - deal with parsing intrinsic URIs and verifying their types, it is only a - container. Parsing and type analysis are a responsibility of the components - that manipulate ASTs. See intrinsic_defs.py for the list of known intrinsics. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Intrinsic': - _check_computation_oneof(computation_proto, 'intrinsic') - return cls( - computation_proto.intrinsic.uri, - type_serialization.deserialize_type(computation_proto.type), - ) - - def __init__(self, uri: str, type_signature: computation_types.Type): - """Creates an intrinsic. - - Args: - uri: The URI of the intrinsic. - type_signature: A `tff.Type`, the type of the intrinsic. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - py_typecheck.check_type(uri, str) - py_typecheck.check_type(type_signature, computation_types.Type) - intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri) - if intrinsic_def is not None: - # Note: this is really expensive. - type_analysis.check_concrete_instance_of( - type_signature, intrinsic_def.type_signature - ) - super().__init__(type_signature) - self._uri = uri - - def _proto(self) -> pb.Computation: - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - intrinsic=pb.Intrinsic(uri=self._uri), - ) - - def intrinsic_def(self) -> intrinsic_defs.IntrinsicDef: - intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(self._uri) - if intrinsic_def is None: - raise ValueError( - 'Failed to retrieve definition of intrinsic with URI ' - f'`{self._uri}`. Perhaps a definition needs to be added to ' - '`intrinsic_defs.py`?' - ) - return intrinsic_def - - def children(self) -> Iterator[ComputationBuildingBlock]: - del self - return iter(()) - - @property - def uri(self) -> str: - return self._uri - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Intrinsic): - return NotImplemented - return ( - self._uri, - self._type_signature, - ) == ( - other._uri, - other._type_signature, - ) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((self._uri, self._type_signature)) - return self._cached_hash - - def __repr__(self) -> str: - return "Intrinsic('{}', {!r})".format(self._uri, self.type_signature) - - -class Data(ComputationBuildingBlock): - """A representation of data (an input pipeline). - - This class does not deal with parsing data protos and verifying correctness, - it is only a container. Parsing and type analysis are a responsibility - or a component external to this module. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Data': - _check_computation_oneof(computation_proto, 'data') - return cls( - computation_proto.data.content, - type_serialization.deserialize_type(computation_proto.type), - ) - - def __init__(self, content: any_pb2.Any, type_spec: object): - """Creates a representation of data. - - Args: - content: The proto that characterizes the data. - type_spec: Either the types.Type that represents the type of this data, or - something convertible to it by types.to_type(). - - Raises: - TypeError: if the arguments are of the wrong types. - ValueError: if the user tries to specify an empty URI. - """ - if type_spec is None: - raise TypeError('Expected `type_spec` to not be `None`.') - type_spec = computation_types.to_type(type_spec) - super().__init__(type_spec) - self._content = content - - def _proto(self) -> pb.Computation: - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - data=pb.Data(content=self._content), - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - del self - return iter(()) - - @property - def content(self) -> any_pb2.Any: - return self._content - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Data): - return NotImplemented - return ( - self._content, - self._type_signature, - ) == ( - other._content, - other._type_signature, - ) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((str(self._content), self._type_signature)) - return self._cached_hash - - def __repr__(self) -> str: - return 'Data({!r}, {!r})'.format(self._content, self.type_signature) - - -class CompiledComputation(ComputationBuildingBlock): - """A representation of a fully constructed and serialized computation. - - A compiled computation is one that has not been parsed into constituents, and - is simply represented as an embedded `Computation` protocol buffer. Whereas - technically, any computation can be represented and passed around this way, - this structure is generally only used to represent TensorFlow sections, for - which otherwise there isn't any dedicated structure. - """ - - def __init__( - self, - proto: pb.Computation, - name: Optional[str] = None, - type_signature: Optional[computation_types.Type] = None, - ): - """Creates a representation of a fully constructed computation. - - Args: - proto: An instance of pb.Computation with the computation logic. - name: An optional string name to associate with this computation, used - only for debugging purposes. If the name is not specified (None), it is - autogenerated as a hexadecimal string from the hash of the proto. - type_signature: An optional type signature to associate with this - computation rather than the serialized one. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - py_typecheck.check_type(proto, pb.Computation) - if name is not None: - py_typecheck.check_type(name, str) - if type_signature is None: - type_signature = type_serialization.deserialize_type(proto.type) - py_typecheck.check_type(type_signature, computation_types.Type) - super().__init__(type_signature) - self._proto_representation = proto - if name is not None: - self._name = name - else: - self._name = '{:x}'.format( - zlib.adler32(self._proto_representation.SerializeToString()) - ) - - def _proto(self) -> pb.Computation: - return self._proto_representation - - def children(self) -> Iterator[ComputationBuildingBlock]: - del self - return iter(()) - - @property - def name(self) -> str: - return self._name - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, CompiledComputation): - return NotImplemented - return ( - self._proto_representation, - self._name, - self._type_signature, - ) == ( - other._proto_representation, - other._name, - other._type_signature, - ) - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash(( - self._proto_representation.SerializeToString(), - self._name, - self._type_signature, - )) - return self._cached_hash - - def __repr__(self) -> str: - return "CompiledComputation('{}', {!r})".format( - self._name, self.type_signature - ) - - -class Placement(ComputationBuildingBlock): - """A representation of a placement literal in TFF's internal language. - - Currently this can only be `tff.SERVER` or `tff.CLIENTS`. - """ - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Placement': - _check_computation_oneof(computation_proto, 'placement') - return cls( - placements.uri_to_placement_literal( - str(computation_proto.placement.uri) - ) - ) - - def __init__(self, literal: placements.PlacementLiteral): - """Constructs a new placement instance for the given placement literal. - - Args: - literal: The placement literal. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - py_typecheck.check_type(literal, placements.PlacementLiteral) - super().__init__(computation_types.PlacementType()) - self._literal = literal - - def _proto(self) -> pb.Computation: - return pb.Computation( - type=type_serialization.serialize_type(self.type_signature), - placement=pb.Placement(uri=self._literal.uri), - ) - - def children(self) -> Iterator[ComputationBuildingBlock]: - del self - return iter(()) - - @property - def uri(self) -> str: - return self._literal.uri - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Placement): - return NotImplemented - return self._literal == other._literal - - def __hash__(self): - if self._cached_hash is None: - self._cached_hash = hash((self._literal)) - return self._cached_hash - - def __repr__(self) -> str: - return "Placement('{}')".format(self.uri) - - -class Literal(ComputationBuildingBlock): - """A representation of a literal in TFF's internal language.""" - - def __init__( - self, value: array.Array, type_signature: computation_types.TensorType - ): - """Returns an initialized `tff.framework.Literal`. - - Args: - value: The value of the literal. - type_signature: A `tff.TensorType`. - - Raises: - ValueError: If `value` is not compatible with `type_signature`. - """ - if ( - isinstance(value, (np.ndarray, np.generic)) - and value.dtype.type is np.str_ - ): - value = value.astype(np.bytes_) - elif isinstance(value, str): - value = value.encode() - - if not array.is_compatible_dtype(value, type_signature.dtype.type): - raise ValueError( - f"Expected '{value}' to be compatible with" - f" '{type_signature.dtype.type}'." - ) - - if not array.is_compatible_shape(value, type_signature.shape): - raise ValueError( - f"Expected '{value}' to be compatible with '{type_signature.shape}'." - ) - - super().__init__(type_signature) - self._value = value - self._type_signature = type_signature - self._cached_hash = None - - @property - def type_signature(self) -> computation_types.TensorType: - return self._type_signature - - @classmethod - def from_proto(cls, computation_proto: pb.Computation) -> 'Literal': - _check_computation_oneof(computation_proto, 'literal') - value = array.from_proto(computation_proto.literal.value) - type_signature = type_serialization.deserialize_type(computation_proto.type) - if not isinstance(type_signature, computation_types.TensorType): - raise ValueError( - 'Expected `type_signature` to be a `tff.TensorType`, found' - f' {type(type_signature)}.' - ) - return cls(value, type_signature) - - def _proto(self) -> pb.Computation: - type_pb = type_serialization.serialize_type(self.type_signature) - value_pb = array.to_proto( - self._value, dtype_hint=self.type_signature.dtype.type - ) - literal_pb = pb.Literal(value=value_pb) - return pb.Computation(type=type_pb, literal=literal_pb) - - def children(self) -> Iterator[ComputationBuildingBlock]: - return iter(()) - - @property - def value(self) -> object: - return self._value - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, Literal): - return NotImplemented - - if self._type_signature != other._type_signature: - return False - if isinstance(self._value, np.ndarray) and isinstance( - other._value, np.ndarray - ): - return np.array_equal(self._value, other._value) - else: - return self._value == other._value - - def __hash__(self): - if self._cached_hash is None: - if isinstance(self._value, (np.ndarray, np.generic)): - hashable_value = tuple(self._value.flatten().tolist()) - else: - hashable_value = self._value - self._cached_hash = hash((hashable_value, self._type_signature)) - return self._cached_hash - - def __repr__(self) -> str: - if isinstance(self._value, np.ndarray): - value_repr = ( - f'np.array({self._value.tolist()},' - f' dtype=np.{self._value.dtype.type.__name__})' - ) - else: - value_repr = repr(self._value) - return f'Literal({value_repr}, {self._type_signature!r})' - - -def _string_representation( - comp: ComputationBuildingBlock, - formatted: bool, -) -> str: - """Returns the string representation of a `ComputationBuildingBlock`. - - This functions creates a `list` of strings representing the given `comp`; - combines the strings in either a formatted or un-formatted representation; and - returns the resulting string represetnation. - - Args: - comp: An instance of a `ComputationBuildingBlock`. - formatted: A boolean indicating if the returned string should be formatted. - - Raises: - TypeError: If `comp` has an unepxected type. - """ - py_typecheck.check_type(comp, ComputationBuildingBlock) - - def _join(components: Iterable[list[str]]) -> list[str]: - """Returns a `list` of strings by combining each component in `components`. - - >>> _join([['a'], ['b'], ['c']]) - ['abc'] - - >>> _join([['a', 'b', 'c'], ['d', 'e', 'f']]) - ['abcd', 'ef'] - - This function is used to help track where new-lines should be inserted into - the string representation if the lines are formatted. - - Args: - components: A `list` where each element is a `list` of strings - representing a part of the string of a `ComputationBuildingBlock`. - """ - lines = [''] - for component in components: - lines[-1] = '{}{}'.format(lines[-1], component[0]) - lines.extend(component[1:]) - return lines - - def _indent(lines, indent_chars=' '): - """Returns a `list` of strings indented across a slice.""" - return ['{}{}'.format(indent_chars, e) for e in lines] - - def _lines_for_named_comps(named_comps, formatted): - """Returns a `list` of strings representing the given `named_comps`. - - Args: - named_comps: A `list` of named computations, each being a pair consisting - of a name (either a string, or `None`) and a `ComputationBuildingBlock`. - formatted: A boolean indicating if the returned string should be - formatted. - """ - lines = [] - for index, (name, comp) in enumerate(named_comps): - if index != 0: - if formatted: - lines.append([',', '']) - else: - lines.append([',']) - element_lines = _lines_for_comp(comp, formatted) - if name is not None: - element_lines = _join([ - ['{}='.format(name)], - element_lines, - ]) - lines.append(element_lines) - return _join(lines) - - def _lines_for_comp(comp, formatted): - """Returns a `list` of strings representing the given `comp`. - - Args: - comp: An instance of a `ComputationBuildingBlock`. - formatted: A boolean indicating if the returned string should be - formatted. - """ - if isinstance(comp, Block): - lines = [] - variables_lines = _lines_for_named_comps(comp.locals, formatted) - if formatted: - variables_lines = _indent(variables_lines) - lines.extend([['(let ', ''], variables_lines, ['', ' in ']]) - else: - lines.extend([['(let '], variables_lines, [' in ']]) - result_lines = _lines_for_comp(comp.result, formatted) - lines.append(result_lines) - lines.append([')']) - return _join(lines) - elif isinstance(comp, Reference): - if comp.context is not None: - return ['{}@{}'.format(comp.name, comp.context)] - else: - return [comp.name] - elif isinstance(comp, Selection): - source_lines = _lines_for_comp(comp.source, formatted) - if comp.name is not None: - return _join([source_lines, ['.{}'.format(comp.name)]]) - else: - return _join([source_lines, ['[{}]'.format(comp.index)]]) - elif isinstance(comp, Call): - function_lines = _lines_for_comp(comp.function, formatted) - if comp.argument is not None: - argument_lines = _lines_for_comp(comp.argument, formatted) - return _join([function_lines, ['('], argument_lines, [')']]) - else: - return _join([function_lines, ['()']]) - elif isinstance(comp, CompiledComputation): - return ['comp#{}'.format(comp.name)] - elif isinstance(comp, Data): - return [str(id(comp.content))] - elif isinstance(comp, Intrinsic): - return [comp.uri] - elif isinstance(comp, Lambda): - result_lines = _lines_for_comp(comp.result, formatted) - if comp.parameter_type is None: - param_name = '' - else: - param_name = comp.parameter_name - lines = [['({} -> '.format(param_name)], result_lines, [')']] - return _join(lines) - elif isinstance(comp, Placement): - placement_literal = placements.uri_to_placement_literal(comp.uri) - return [placement_literal.name] - elif isinstance(comp, Literal): - return [str(comp.value)] - elif isinstance(comp, Struct): - if not comp: - return ['<>'] - elements = structure.to_elements(comp) - elements_lines = _lines_for_named_comps(elements, formatted) - if formatted: - elements_lines = _indent(elements_lines) - lines = [['<', ''], elements_lines, ['', '>']] - else: - lines = [['<'], elements_lines, ['>']] - return _join(lines) - else: - raise NotImplementedError('Unexpected type found: {}.'.format(type(comp))) - - lines = _lines_for_comp(comp, formatted) - lines = [line.rstrip() for line in lines] - if formatted: - return '\n'.join(lines) - else: - return ''.join(lines) - - -def _structural_representation(comp): - """Returns the structural string representation of the given `comp`. - - This functions creates and returns a string representing the structure of the - abstract syntax tree for the given `comp`. - - Args: - comp: An instance of a `ComputationBuildingBlock`. - - Raises: - TypeError: If `comp` has an unepxected type. - """ - py_typecheck.check_type(comp, ComputationBuildingBlock) - padding_char = ' ' - - def _get_leading_padding(string): - """Returns the length of the leading padding for the given `string`.""" - for index, character in enumerate(string): - if character != padding_char: - return index - return len(string) - - def _get_trailing_padding(string): - """Returns the length of the trailing padding for the given `string`.""" - for index, character in enumerate(reversed(string)): - if character != padding_char: - return index - return len(string) - - def _pad_left(lines, total_width): - """Pads the beginning of each line in `lines` to the given `total_width`. - - >>>_pad_left(['aa', 'bb'], 4) - [' aa', ' bb',] - - Args: - lines: A `list` of strings to pad. - total_width: The length that each line in `lines` should be padded to. - - Returns: - A `list` of lines with padding applied. - """ - - def _pad_line_left(line, total_width): - current_width = len(line) - assert current_width <= total_width - padding = total_width - current_width - return '{}{}'.format(padding_char * padding, line) - - return [_pad_line_left(line, total_width) for line in lines] - - def _pad_right(lines, total_width): - """Pads the end of each line in `lines` to the given `total_width`. - - >>>_pad_right(['aa', 'bb'], 4) - ['aa ', 'bb '] - - Args: - lines: A `list` of strings to pad. - total_width: The length that each line in `lines` should be padded to. - - Returns: - A `list` of lines with padding applied. - """ - - def _pad_line_right(line, total_width): - current_width = len(line) - assert current_width <= total_width - padding = total_width - current_width - return '{}{}'.format(line, padding_char * padding) - - return [_pad_line_right(line, total_width) for line in lines] - - class Alignment(enum.Enum): - LEFT = 1 - RIGHT = 2 - - def _concatenate(lines_1, lines_2, align): - """Concatenates two `list`s of strings. - - Concatenates two `list`s of strings by appending one list of strings to the - other and then aligning lines of different widths by either padding the left - or padding the right of each line to the width of the longest line. - - >>>_concatenate(['aa', 'bb'], ['ccc'], Alignment.LEFT) - ['aa ', 'bb ', 'ccc'] - - Args: - lines_1: A `list` of strings. - lines_2: A `list` of strings. - align: An enum indicating how to align lines of different widths. - - Returns: - A `list` of lines. - """ - lines = lines_1 + lines_2 - longest_line = max(lines, key=len) - longest_width = len(longest_line) - if align is Alignment.LEFT: - return _pad_right(lines, longest_width) - elif align is Alignment.RIGHT: - return _pad_left(lines, longest_width) - - def _calculate_inset_from_padding( - left, right, preferred_padding, minimum_content_padding - ): - """Calculates the inset for the given padding. - - Note: This function is intended to only be called from `_fit_with_padding`. - - Args: - left: A `list` of strings. - right: A `list` of strings. - preferred_padding: The preferred amount of non-negative padding between - the lines in the fitted `list` of strings. - minimum_content_padding: The minimum amount of non-negative padding - allowed between the lines in the fitted `list` of strings. - - Returns: - An integer. - """ - assert preferred_padding >= 0 - assert minimum_content_padding >= 0 - - trailing_padding = _get_trailing_padding(left[0]) - leading_padding = _get_leading_padding(right[0]) - inset = trailing_padding + leading_padding - preferred_padding - for left_line, right_line in zip(left[1:], right[1:]): - trailing_padding = _get_trailing_padding(left_line) - leading_padding = _get_leading_padding(right_line) - minimum_inset = ( - trailing_padding + leading_padding - minimum_content_padding - ) - inset = min(inset, minimum_inset) - return inset - - def _fit_with_inset(left, right, inset): - r"""Concatenates the lines of two `list`s of strings. - - Note: This function is intended to only be called from `_fit_with_padding`. - - Args: - left: A `list` of strings. - right: A `list` of strings. - inset: The amount of padding to remove or add when concatenating the - lines. - - Returns: - A `list` of lines. - """ - lines = [] - for left_line, right_line in zip(left, right): - if inset > 0: - left_inset = 0 - trailing_padding = _get_trailing_padding(left_line) - if trailing_padding > 0: - left_inset = min(trailing_padding, inset) - left_line = left_line[:-left_inset] - if inset - left_inset > 0: - leading_padding = _get_leading_padding(right_line) - if leading_padding > 0: - right_inset = min(leading_padding, inset - left_inset) - right_line = right_line[right_inset:] - padding = abs(inset) if inset < 0 else 0 - line = ''.join([left_line, padding_char * padding, right_line]) - lines.append(line) - left_height = len(left) - right_height = len(right) - if left_height > right_height: - lines.extend(left[right_height:]) - elif right_height > left_height: - lines.extend(right[left_height:]) - longest_line = max(lines, key=len) - longest_width = len(longest_line) - shortest_line = min(lines, key=len) - shortest_width = len(shortest_line) - if shortest_width != longest_width: - if left_height > right_height: - lines = _pad_right(lines, longest_width) - else: - lines = _pad_left(lines, longest_width) - return lines - - def _fit_with_padding( - left, right, preferred_padding, minimum_content_padding=4 - ): - r"""Concatenates the lines of two `list`s of strings. - - Concatenates the lines of two `list`s of strings by appending each line - together using a padding. The same padding is used to append each line and - the padding is calculated starting from the `preferred_padding` without - going below `minimum_content_padding` on any of the lines. If the two - `list`s of strings have different lengths, padding will be applied to - maintain the length of each string in the resulting `list` of strings. - - >>>_fit_with_padding(['aa', 'bb'], ['ccc']) - ['aa cccc', 'bb '] - - >>>_fit_with_padding(['aa ', 'bb '], [' ccc']) - ['aa cccc', 'bb '] - - Args: - left: A `list` of strings. - right: A `list` of strings. - preferred_padding: The preferred amount of non-negative padding between - the lines in the fitted `list` of strings. - minimum_content_padding: The minimum amount of non-negative padding - allowed between the lines in the fitted `list` of strings. - - Returns: - A `list` of lines. - """ - inset = _calculate_inset_from_padding( - left, right, preferred_padding, minimum_content_padding - ) - return _fit_with_inset(left, right, inset) - - def _get_node_label(comp): - """Returns a string for node in the structure of the given `comp`.""" - if isinstance(comp, Block): - return 'Block' - elif isinstance(comp, Call): - return 'Call' - elif isinstance(comp, CompiledComputation): - return 'Compiled({})'.format(comp.name) - elif isinstance(comp, Data): - return f'Data({id(comp.content)})' - elif isinstance(comp, Intrinsic): - return comp.uri - elif isinstance(comp, Lambda): - return 'Lambda({})'.format(comp.parameter_name) - elif isinstance(comp, Reference): - return 'Ref({})'.format(comp.name) - elif isinstance(comp, Placement): - return 'Placement' - elif isinstance(comp, Selection): - key = comp.name if comp.name is not None else comp.index - return 'Sel({})'.format(key) - elif isinstance(comp, Struct): - return 'Struct' - elif isinstance(comp, Literal): - return f'Lit({comp.value})' - else: - raise TypeError('Unexpected type found: {}.'.format(type(comp))) - - def _lines_for_named_comps(named_comps): - """Returns a `list` of strings representing the given `named_comps`. - - Args: - named_comps: A `list` of named computations, each being a pair consisting - of a name (either a string, or `None`) and a `ComputationBuildingBlock`. - """ - lines = ['['] - for index, (name, comp) in enumerate(named_comps): - comp_lines = _lines_for_comp(comp) - if name is not None: - label = '{}='.format(name) - comp_lines = _fit_with_padding([label], comp_lines, 0, 0) - if index == 0: - lines = _fit_with_padding(lines, comp_lines, 0, 0) - else: - lines = _fit_with_padding(lines, [','], 0, 0) - lines = _fit_with_padding(lines, comp_lines, 1) - lines = _fit_with_padding(lines, [']'], 0, 0) - return lines - - def _lines_for_comp(comp): - """Returns a `list` of strings representing the given `comp`. - - Args: - comp: An instance of a `ComputationBuildingBlock`. - """ - node_label = _get_node_label(comp) - - if isinstance( - comp, - ( - CompiledComputation, - Data, - Intrinsic, - Placement, - Reference, - Literal, - ), - ): - return [node_label] - elif isinstance(comp, Block): - variables_lines = _lines_for_named_comps(comp.locals) - variables_width = len(variables_lines[0]) - variables_trailing_padding = _get_trailing_padding(variables_lines[0]) - leading_padding = variables_width - variables_trailing_padding - edge_line = '{}/'.format(padding_char * leading_padding) - variables_lines = _concatenate( - [edge_line], variables_lines, Alignment.LEFT - ) - - result_lines = _lines_for_comp(comp.result) - result_width = len(result_lines[0]) - leading_padding = _get_leading_padding(result_lines[0]) - 1 - trailing_padding = result_width - leading_padding - 1 - edge_line = '\\{}'.format(padding_char * trailing_padding) - result_lines = _concatenate([edge_line], result_lines, Alignment.RIGHT) - - preferred_padding = len(node_label) - lines = _fit_with_padding( - variables_lines, result_lines, preferred_padding - ) - leading_padding = _get_leading_padding(lines[0]) + 1 - node_line = '{}{}'.format(padding_char * leading_padding, node_label) - return _concatenate([node_line], lines, Alignment.LEFT) - elif isinstance(comp, Call): - function_lines = _lines_for_comp(comp.function) - function_width = len(function_lines[0]) - function_trailing_padding = _get_trailing_padding(function_lines[0]) - leading_padding = function_width - function_trailing_padding - edge_line = '{}/'.format(padding_char * leading_padding) - function_lines = _concatenate([edge_line], function_lines, Alignment.LEFT) - - if comp.argument is not None: - argument_lines = _lines_for_comp(comp.argument) - argument_width = len(argument_lines[0]) - leading_padding = _get_leading_padding(argument_lines[0]) - 1 - trailing_padding = argument_width - leading_padding - 1 - edge_line = '\\{}'.format(padding_char * trailing_padding) - argument_lines = _concatenate( - [edge_line], argument_lines, Alignment.RIGHT - ) - - preferred_padding = len(node_label) - lines = _fit_with_padding( - function_lines, argument_lines, preferred_padding - ) - else: - lines = function_lines - leading_padding = _get_leading_padding(lines[0]) + 1 - node_line = '{}{}'.format(padding_char * leading_padding, node_label) - return _concatenate([node_line], lines, Alignment.LEFT) - elif isinstance(comp, Lambda): - result_lines = _lines_for_comp(comp.result) - leading_padding = _get_leading_padding(result_lines[0]) - node_line = '{}{}'.format(padding_char * leading_padding, node_label) - edge_line = '{}|'.format(padding_char * leading_padding) - return _concatenate([node_line, edge_line], result_lines, Alignment.LEFT) - elif isinstance(comp, Selection): - source_lines = _lines_for_comp(comp.source) - leading_padding = _get_leading_padding(source_lines[0]) - node_line = '{}{}'.format(padding_char * leading_padding, node_label) - edge_line = '{}|'.format(padding_char * leading_padding) - return _concatenate([node_line, edge_line], source_lines, Alignment.LEFT) - elif isinstance(comp, Struct): - elements = structure.to_elements(comp) - elements_lines = _lines_for_named_comps(elements) - leading_padding = _get_leading_padding(elements_lines[0]) - node_line = '{}{}'.format(padding_char * leading_padding, node_label) - edge_line = '{}|'.format(padding_char * leading_padding) - return _concatenate( - [node_line, edge_line], elements_lines, Alignment.LEFT - ) - else: - raise NotImplementedError('Unexpected type found: {}.'.format(type(comp))) - - lines = _lines_for_comp(comp) - lines = [line.rstrip() for line in lines] - return '\n'.join(lines) - - -_deserializer_dict = { - 'reference': Reference.from_proto, - 'selection': Selection.from_proto, - 'struct': Struct.from_proto, - 'call': Call.from_proto, - 'lambda': Lambda.from_proto, - 'block': Block.from_proto, - 'intrinsic': Intrinsic.from_proto, - 'data': Data.from_proto, - 'placement': Placement.from_proto, - 'literal': Literal.from_proto, - 'tensorflow': CompiledComputation, - 'xla': CompiledComputation, -} diff --git a/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py b/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py deleted file mode 100644 index c874bafeff..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py +++ /dev/null @@ -1,2637 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized -import ml_dtypes -import numpy as np -import tree - -from google.protobuf import any_pb2 -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.proto.v0 import computation_pb2 -from tensorflow_federated.proto.v0 import data_type_pb2 -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import array -from tensorflow_federated.python.core.impl.compiler import building_block_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_serialization - - -_TEST_LITERAL = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) -) - - -def _to_python(obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - else: - return obj - - -class ComputationBuildingBlocksTest(absltest.TestCase): - - def test_basic_functionality_of_reference_class(self): - x = building_blocks.Reference('foo', np.int32) - self.assertEqual(x.name, 'foo') - self.assertEqual(str(x.type_signature), 'int32') - self.assertEqual(repr(x), "Reference('foo', TensorType(np.int32))") - self.assertEqual(x.compact_representation(), 'foo') - x_proto = x.proto - self.assertEqual( - type_serialization.deserialize_type(x_proto.type), x.type_signature - ) - self.assertEqual(x_proto.WhichOneof('computation'), 'reference') - self.assertEqual(x_proto.reference.name, x.name) - self._serialize_deserialize_roundtrip_test(x) - - def test_reference_children_is_empty(self): - ref = building_blocks.Reference('foo', np.int32) - self.assertEqual([], list(ref.children())) - - def test_basic_functionality_of_selection_class(self): - x = building_blocks.Reference( - 'foo', [('bar', np.int32), ('baz', np.float64)] - ) - y = building_blocks.Selection(x, name='bar') - self.assertEqual(y.name, 'bar') - self.assertIsNone(y.index) - self.assertEqual(str(y.type_signature), 'int32') - self.assertEqual( - repr(y), - ( - "Selection(Reference('foo', StructType([" - "('bar', TensorType(np.int32)), ('baz', TensorType(np.float64))]))" - ", name='bar')" - ), - ) - self.assertEqual(y.compact_representation(), 'foo.bar') - z = building_blocks.Selection(x, name='baz') - self.assertEqual(str(z.type_signature), 'float64') - self.assertEqual(z.compact_representation(), 'foo.baz') - with self.assertRaises(ValueError): - _ = building_blocks.Selection(x, name='bak') - x0 = building_blocks.Selection(x, index=0) - self.assertIsNone(x0.name) - self.assertEqual(x0.index, 0) - self.assertEqual(str(x0.type_signature), 'int32') - self.assertEqual( - repr(x0), - ( - "Selection(Reference('foo', StructType([" - "('bar', TensorType(np.int32)), ('baz', TensorType(np.float64))]))" - ', index=0)' - ), - ) - self.assertEqual(x0.compact_representation(), 'foo[0]') - x1 = building_blocks.Selection(x, index=1) - self.assertEqual(str(x1.type_signature), 'float64') - self.assertEqual(x1.compact_representation(), 'foo[1]') - with self.assertRaises(ValueError): - _ = building_blocks.Selection(x, index=2) - with self.assertRaises(ValueError): - _ = building_blocks.Selection(x, index=-1) - y_proto = y.proto - self.assertEqual( - type_serialization.deserialize_type(y_proto.type), y.type_signature - ) - self.assertEqual(y_proto.WhichOneof('computation'), 'selection') - self.assertEqual(str(y_proto.selection.source), str(x.proto)) - # Our serialized representation only uses indices. - self.assertEqual(y_proto.selection.index, 0) - self._serialize_deserialize_roundtrip_test(y) - self._serialize_deserialize_roundtrip_test(z) - self._serialize_deserialize_roundtrip_test(x0) - self._serialize_deserialize_roundtrip_test(x1) - - def test_reference_children_yields_source(self): - source = building_blocks.Reference('foo', (np.int32, np.int32)) - selection = building_blocks.Selection(source, index=1) - self.assertEqual([source], list(selection.children())) - - def test_basic_functionality_of_struct_class(self): - x = building_blocks.Reference('foo', np.int32) - y = building_blocks.Reference('bar', np.float64) - z = building_blocks.Struct([x, ('y', y)]) - with self.assertRaises(ValueError): - _ = building_blocks.Struct([('', y)]) - self.assertIsInstance(z, structure.Struct) - self.assertEqual(str(z.type_signature), '') - self.assertEqual( - repr(z), - ( - "Struct([(None, Reference('foo', TensorType(np.int32))), ('y', " - "Reference('bar', TensorType(np.float64)))])" - ), - ) - self.assertEqual(z.compact_representation(), '') - self.assertEqual(dir(z), ['y']) - self.assertIs(z.y, y) - self.assertLen(z, 2) - self.assertIs(z[0], x) - self.assertIs(z[1], y) - self.assertEqual( - ','.join(e.compact_representation() for e in iter(z)), 'foo,bar' - ) - z_proto = z.proto - self.assertEqual( - type_serialization.deserialize_type(z_proto.type), z.type_signature - ) - self.assertEqual(z_proto.WhichOneof('computation'), 'struct') - self.assertEqual([e.name for e in z_proto.struct.element], ['', 'y']) - self._serialize_deserialize_roundtrip_test(z) - - def test_struct_children_yields_elements(self): - e1 = building_blocks.Reference('a', np.int32) - e2 = building_blocks.Reference('b', np.int32) - struct_ = building_blocks.Struct([(None, e1), (None, e2)]) - self.assertEqual([e1, e2], list(struct_.children())) - - def test_struct_with_container_type(self): - x = building_blocks.Reference('foo', np.int32) - y = building_blocks.Reference('bar', np.float64) - z = building_blocks.Struct([x, ('y', y)], tuple) - self.assertEqual( - z.type_signature, - computation_types.StructWithPythonType( - [np.int32, ('y', np.float64)], tuple - ), - ) - - def test_basic_functionality_of_call_class(self): - x = building_blocks.Reference( - 'foo', computation_types.FunctionType(np.int32, np.float64) - ) - y = building_blocks.Reference('bar', np.int32) - z = building_blocks.Call(x, y) - self.assertEqual(str(z.type_signature), 'float64') - self.assertIs(z.function, x) - self.assertIs(z.argument, y) - self.assertEqual( - repr(z), - ( - "Call(Reference('foo', " - 'FunctionType(TensorType(np.int32), TensorType(np.float64))), ' - "Reference('bar', TensorType(np.int32)))" - ), - ) - self.assertEqual(z.compact_representation(), 'foo(bar)') - with self.assertRaises(TypeError): - building_blocks.Call(x) - w = building_blocks.Reference('bak', np.float32) - with self.assertRaises(TypeError): - building_blocks.Call(x, w) - z_proto = z.proto - self.assertEqual( - type_serialization.deserialize_type(z_proto.type), z.type_signature - ) - self.assertEqual(z_proto.WhichOneof('computation'), 'call') - self.assertEqual(str(z_proto.call.function), str(x.proto)) - self.assertEqual(str(z_proto.call.argument), str(y.proto)) - self._serialize_deserialize_roundtrip_test(z) - - def test_call_children_with_no_arg_yields_function(self): - fn = building_blocks.Reference( - 'a', computation_types.FunctionType(None, np.int32) - ) - call = building_blocks.Call(fn) - self.assertEqual([fn], list(call.children())) - - def test_call_children_with_arg_yields_function_and_arg(self): - fn = building_blocks.Reference( - 'a', computation_types.FunctionType(np.int32, np.int32) - ) - arg = building_blocks.Reference('b', np.int32) - call = building_blocks.Call(fn, arg) - self.assertEqual([fn, arg], list(call.children())) - - def test_basic_functionality_of_lambda_class(self): - arg_name = 'arg' - arg_type = [ - ('f', computation_types.FunctionType(np.int32, np.int32)), - ('x', np.int32), - ] - arg = building_blocks.Reference(arg_name, arg_type) - arg_f = building_blocks.Selection(arg, name='f') - arg_x = building_blocks.Selection(arg, name='x') - x = building_blocks.Lambda( - arg_name, - arg_type, - building_blocks.Call(arg_f, building_blocks.Call(arg_f, arg_x)), - ) - self.assertEqual( - str(x.type_signature), '( int32),x=int32> -> int32)' - ) - self.assertEqual(x.parameter_name, arg_name) - self.assertEqual(str(x.parameter_type), ' int32),x=int32>') - self.assertEqual(x.result.compact_representation(), 'arg.f(arg.f(arg.x))') - arg_type_repr = ( - 'StructType([' - "('f', FunctionType(TensorType(np.int32), TensorType(np.int32))), " - "('x', TensorType(np.int32))])" - ) - self.assertEqual( - repr(x), - "Lambda('arg', {0}, " - "Call(Selection(Reference('arg', {0}), name='f'), " - "Call(Selection(Reference('arg', {0}), name='f'), " - "Selection(Reference('arg', {0}), name='x'))))".format(arg_type_repr), - ) - self.assertEqual(x.compact_representation(), '(arg -> arg.f(arg.f(arg.x)))') - x_proto = x.proto - self.assertEqual( - type_serialization.deserialize_type(x_proto.type), x.type_signature - ) - self.assertEqual(x_proto.WhichOneof('computation'), 'lambda') - self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name) - self.assertEqual( - str(getattr(x_proto, 'lambda').result), str(x.result.proto) - ) - self._serialize_deserialize_roundtrip_test(x) - - def test_lambda_children_returns_result(self): - result = building_blocks.Reference('a', np.int32) - lambda_ = building_blocks.Lambda('a', np.int32, result) - self.assertEqual([result], list(lambda_.children())) - - def test_basic_functionality_of_block_class(self): - x = building_blocks.Block( - [ - ('x', building_blocks.Reference('arg', (np.int32, np.int32))), - ( - 'y', - building_blocks.Selection( - building_blocks.Reference('x', (np.int32, np.int32)), - index=0, - ), - ), - ], - building_blocks.Reference('y', np.int32), - ) - self.assertEqual(str(x.type_signature), 'int32') - self.assertEqual( - [(k, v.compact_representation()) for k, v in x.locals], - [('x', 'arg'), ('y', 'x[0]')], - ) - self.assertEqual(x.result.compact_representation(), 'y') - self.assertEqual( - repr(x), - ( - "Block([('x', Reference('arg', StructType([TensorType(np.int32)," - " TensorType(np.int32)]) as tuple)), ('y', Selection(Reference('x'," - ' StructType([TensorType(np.int32), TensorType(np.int32)]) as' - " tuple), index=0))], Reference('y', TensorType(np.int32)))" - ), - ) - self.assertEqual(x.compact_representation(), '(let x=arg,y=x[0] in y)') - x_proto = x.proto - self.assertEqual( - type_serialization.deserialize_type(x_proto.type), x.type_signature - ) - self.assertEqual(x_proto.WhichOneof('computation'), 'block') - self.assertEqual(str(x_proto.block.result), str(x.result.proto)) - for idx, loc_proto in enumerate(x_proto.block.local): - loc_name, loc_value = x.locals[idx] - self.assertEqual(loc_proto.name, loc_name) - self.assertEqual(str(loc_proto.value), str(loc_value.proto)) - self._serialize_deserialize_roundtrip_test(x) - - def test_block_children_returns_locals_then_result(self): - l1 = building_blocks.Reference('a', np.int32) - l2 = building_blocks.Reference('b', np.int32) - result = building_blocks.Reference('c', np.int32) - block = building_blocks.Block([('1', l1), ('2', l2)], result) - self.assertEqual([l1, l2, result], list(block.children())) - - def test_basic_functionality_of_intrinsic_class(self): - x = building_blocks.Intrinsic( - 'add_one', computation_types.FunctionType(np.int32, np.int32) - ) - self.assertEqual(str(x.type_signature), '(int32 -> int32)') - self.assertEqual(x.uri, 'add_one') - self.assertEqual( - repr(x), - ( - "Intrinsic('add_one', " - 'FunctionType(TensorType(np.int32), TensorType(np.int32)))' - ), - ) - self.assertEqual(x.compact_representation(), 'add_one') - x_proto = x.proto - self.assertEqual( - type_serialization.deserialize_type(x_proto.type), x.type_signature - ) - self.assertEqual(x_proto.WhichOneof('computation'), 'intrinsic') - self.assertEqual(x_proto.intrinsic.uri, x.uri) - self._serialize_deserialize_roundtrip_test(x) - - def test_intrinsic_children_is_empty(self): - intrinsic = building_blocks.Intrinsic( - 'a', computation_types.FunctionType(np.int32, np.int32) - ) - self.assertEqual([], list(intrinsic.children())) - - def test_basic_intrinsic_functionality_plus_canonical_typecheck(self): - x = building_blocks.Intrinsic( - 'generic_plus', - computation_types.FunctionType([np.int32, np.int32], np.int32), - ) - self.assertEqual(str(x.type_signature), '( -> int32)') - self.assertEqual(x.uri, 'generic_plus') - self.assertEqual(x.compact_representation(), 'generic_plus') - x_proto = x.proto - deserialized_type = type_serialization.deserialize_type(x_proto.type) - x.type_signature.check_assignable_from(deserialized_type) - self.assertEqual(x_proto.WhichOneof('computation'), 'intrinsic') - self.assertEqual(x_proto.intrinsic.uri, x.uri) - self._serialize_deserialize_roundtrip_test(x) - - def test_intrinsic_class_fails_bad_type(self): - with self.assertRaises(TypeError): - _ = building_blocks.Intrinsic( - intrinsic_defs.GENERIC_PLUS.uri, - computation_types.FunctionType([np.int32, np.int32], np.float32), - ) - - def test_intrinsic_class_fails_struct_type_with_names(self): - with self.assertRaises(TypeError): - _ = building_blocks.Intrinsic( - intrinsic_defs.GENERIC_PLUS.uri, - computation_types.FunctionType( - [('a', np.int32), ('b', np.int32)], np.int32 - ), - ) - - def test_intrinsic_class_succeeds_simple_federated_map(self): - simple_function = computation_types.FunctionType(np.int32, np.float32) - federated_arg = computation_types.FederatedType( - simple_function.parameter, placements.CLIENTS - ) - federated_result = computation_types.FederatedType( - simple_function.result, placements.CLIENTS - ) - federated_map_concrete_type = computation_types.FunctionType( - computation_types.StructType((simple_function, federated_arg)), - federated_result, - ) - concrete_federated_map = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MAP.uri, federated_map_concrete_type - ) - self.assertIsInstance(concrete_federated_map, building_blocks.Intrinsic) - self.assertEqual( - str(concrete_federated_map.type_signature), - '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)', - ) - self.assertEqual(concrete_federated_map.uri, 'federated_map') - self.assertEqual( - concrete_federated_map.compact_representation(), 'federated_map' - ) - concrete_federated_map_proto = concrete_federated_map.proto - self.assertEqual( - type_serialization.deserialize_type(concrete_federated_map_proto.type), - concrete_federated_map.type_signature, - ) - self.assertEqual( - concrete_federated_map_proto.WhichOneof('computation'), 'intrinsic' - ) - self.assertEqual( - concrete_federated_map_proto.intrinsic.uri, concrete_federated_map.uri - ) - self._serialize_deserialize_roundtrip_test(concrete_federated_map) - - def test_basic_functionality_of_data_class(self): - test_proto = array.to_proto(np.array([1, 2, 3], np.int32)) - any_proto = any_pb2.Any() - any_proto.Pack(test_proto) - x = building_blocks.Data( - any_proto, computation_types.SequenceType(np.int32) - ) - self.assertEqual(str(x.type_signature), 'int32*') - self.assertEqual(x.content, any_proto) - arr = array_pb2.Array() - x.content.Unpack(arr) - self.assertEqual(arr, test_proto) - as_string = str(any_proto) - self.assertEqual( - repr(x), f'Data({as_string}, SequenceType(TensorType(np.int32)))' - ) - as_id_string = str(id(any_proto)) - self.assertEqual( - x.compact_representation(), - as_id_string, - ) - x_proto = x.proto - self.assertEqual( - type_serialization.deserialize_type(x_proto.type), x.type_signature - ) - self.assertEqual(x_proto.WhichOneof('computation'), 'data') - self.assertEqual(x_proto.data.content, x.content) - self._serialize_deserialize_roundtrip_test(x) - - def test_data_children_is_empty(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array(1, np.int32) - ) - data = building_blocks.Data(any_proto, np.int32) - self.assertEqual([], list(data.children())) - - def test_basic_functionality_of_compiled_computation_class(self): - type_spec = computation_types.TensorType(np.int32) - x_proto = computation_factory.create_lambda_identity(type_spec) - x_type = computation_types.FunctionType(type_spec, type_spec) - x = building_blocks.CompiledComputation( - x_proto, name='a', type_signature=x_type - ) - self.assertEqual( - x.type_signature.compact_representation(), '(int32 -> int32)' - ) - self.assertIsInstance(x.proto, computation_pb2.Computation) - self.assertEqual(x.name, 'a') - self.assertTrue( - repr(x), - ( - "CompiledComputation('a', FunctionType(TensorType(np.int32)," - ' TensorType(np.int32)))' - ), - ) - self.assertTrue(x.compact_representation(), 'comp#a') - y_proto = computation_factory.create_lambda_identity(type_spec) - y_type = computation_types.FunctionType(type_spec, type_spec) - y = building_blocks.CompiledComputation( - y_proto, name='a', type_signature=y_type - ) - self._serialize_deserialize_roundtrip_test(y) - - def test_compiled_computation_children_is_empty(self): - comp_type = computation_types.TensorType(np.int32) - proto = computation_factory.create_lambda_identity(comp_type) - comp = building_blocks.CompiledComputation( - proto, name='a', type_signature=comp_type - ) - self.assertEqual([], list(comp.children())) - - def test_basic_functionality_of_placement_class(self): - x = building_blocks.Placement(placements.CLIENTS) - self.assertEqual(str(x.type_signature), 'placement') - self.assertEqual(x.uri, 'clients') - self.assertEqual(repr(x), "Placement('clients')") - self.assertEqual(x.compact_representation(), 'CLIENTS') - x_proto = x.proto - self.assertEqual( - type_serialization.deserialize_type(x_proto.type), x.type_signature - ) - self.assertEqual(x_proto.WhichOneof('computation'), 'placement') - self.assertEqual(x_proto.placement.uri, x.uri) - self._serialize_deserialize_roundtrip_test(x) - - def test_placement_children_is_empty(self): - placement = building_blocks.Placement(placements.CLIENTS) - self.assertEqual([], list(placement.children())) - - def _serialize_deserialize_roundtrip_test(self, target): - """Performs roundtrip serialization/deserialization of the given target. - - Args: - target: An instane of ComputationBuildingBlock to serialize-deserialize. - """ - self.assertIsInstance(target, building_blocks.ComputationBuildingBlock) - serialized = target.proto - deserialized = building_blocks.ComputationBuildingBlock.from_proto( - serialized - ) - reserialized = deserialized.proto - self.assertEqual(str(serialized), str(reserialized)) - # Note: This is not an equality comparison because ser/de is not an identity - # transform: it will drop the container from `StructWithPythonType`. - target.type_signature.check_assignable_from(deserialized.type_signature) - - -class ReferenceTest(parameterized.TestCase): - - def test_eq_returns_true(self): - type_signature = computation_types.TensorType(np.int32) - reference = building_blocks.Reference('reference', type_signature) - - self.assertIs(reference, reference) - self.assertEqual(reference, reference) - - @parameterized.named_parameters( - ( - 'same_name_and_type_signature', - building_blocks.Reference( - 'reference', computation_types.TensorType(np.int32) - ), - building_blocks.Reference( - 'reference', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_name', - building_blocks.Reference( - 'reference', computation_types.TensorType(np.int32) - ), - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_type_signature', - building_blocks.Reference( - 'reference', computation_types.TensorType(np.int32) - ), - building_blocks.Reference( - 'reference', computation_types.TensorType(np.float32) - ), - ), - ) - def test_eq_returns_false(self, reference, other): - self.assertIsNot(reference, other) - self.assertNotEqual(reference, other) - - def test_hash_returns_same_value(self): - type_signature = computation_types.TensorType(np.int32) - reference = building_blocks.Reference('reference', type_signature) - other = building_blocks.Reference('reference', type_signature) - - self.assertEqual(hash(reference), hash(other)) - - @parameterized.named_parameters( - ( - 'different_name', - building_blocks.Reference( - 'reference', computation_types.TensorType(np.int32) - ), - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_type_signature', - building_blocks.Reference( - 'reference', computation_types.TensorType(np.int32) - ), - building_blocks.Reference( - 'reference', computation_types.TensorType(np.float32) - ), - ), - ) - def test_hash_returns_different_value(self, reference, other): - self.assertNotEqual(reference, other) - self.assertNotEqual(hash(reference), hash(other)) - - -class SelectionTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'name', - building_blocks.Selection( - building_blocks.Struct([('x', _TEST_LITERAL)]), - name='x', - ), - building_blocks.Selection( - building_blocks.Struct([('x', _TEST_LITERAL)]), - name='x', - ), - ), - ( - 'index', - building_blocks.Selection( - building_blocks.Struct([_TEST_LITERAL]), - index=0, - ), - building_blocks.Selection( - building_blocks.Struct([_TEST_LITERAL]), - index=0, - ), - ), - ) - def test_eq_returns_true(self, selection, other): - self.assertIsNot(selection, other) - self.assertEqual(selection, other) - - @parameterized.named_parameters( - ( - 'different_source', - building_blocks.Selection( - building_blocks.Struct([_TEST_LITERAL]), - index=0, - ), - building_blocks.Selection( - building_blocks.Struct([_TEST_LITERAL, _TEST_LITERAL]), - index=0, - ), - ), - ( - 'different_name', - building_blocks.Selection( - building_blocks.Struct( - [('x', _TEST_LITERAL), ('y', _TEST_LITERAL)] - ), - name='x', - ), - building_blocks.Selection( - building_blocks.Struct( - [('x', _TEST_LITERAL), ('y', _TEST_LITERAL)] - ), - name='y', - ), - ), - ( - 'different_index', - building_blocks.Selection( - building_blocks.Struct([_TEST_LITERAL, _TEST_LITERAL]), - index=0, - ), - building_blocks.Selection( - building_blocks.Struct([_TEST_LITERAL, _TEST_LITERAL]), - index=1, - ), - ), - ) - def test_eq_returns_false(self, selection, other): - self.assertIsNot(selection, other) - self.assertNotEqual(selection, other) - - @parameterized.named_parameters( - ( - 'name', - building_blocks.Selection( - building_blocks.Struct([ - ( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ), - ]), - name='x', - ), - building_blocks.Selection( - building_blocks.Struct([ - ( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ), - ]), - name='x', - ), - ), - ( - 'index', - building_blocks.Selection( - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - index=0, - ), - building_blocks.Selection( - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - index=0, - ), - ), - ) - def test_hash_returns_same_value(self, selection, other): - self.assertEqual(hash(selection), hash(other)) - - @parameterized.named_parameters( - ( - 'different_source', - building_blocks.Selection( - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - index=0, - ), - building_blocks.Selection( - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - index=0, - ), - ), - ( - 'different_name', - building_blocks.Selection( - building_blocks.Struct([ - ( - 'x', - building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ), - ), - ( - 'y', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ), - ]), - name='x', - ), - building_blocks.Selection( - building_blocks.Struct([ - ( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ), - ( - 'y', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ), - ]), - name='y', - ), - ), - ( - 'different_index', - building_blocks.Selection( - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - index=0, - ), - building_blocks.Selection( - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - index=1, - ), - ), - ) - def test_hash_returns_different_value(self, selection, other): - self.assertNotEqual(selection, other) - self.assertNotEqual(hash(selection), hash(other)) - - -class StructTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'container_type_none', - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - ), - ( - 'container_type_list', - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - ), - ) - def test_eq_returns_true(self, struct, other): - self.assertIsNot(struct, other) - self.assertEqual(struct, other) - - @parameterized.named_parameters( - ( - 'different_elements', - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - ), - ( - 'different_container_type', - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=tuple, - ), - ), - ) - def test_eq_returns_false(self, struct, other): - self.assertIsNot(struct, other) - self.assertNotEqual(struct, other) - - @parameterized.named_parameters( - ( - 'container_type_none', - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - building_blocks.Struct([ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ]), - ), - ( - 'container_type_list', - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - ), - ) - def test_hash_returns_same_value(self, struct, other): - self.assertEqual(hash(struct), hash(other)) - - def test_hash_returns_same_value_with_different_container_type(self): - type_signature = computation_types.TensorType(np.int32) - element = building_blocks.Literal(1, type_signature) - struct = building_blocks.Struct([element], container_type=list) - other = building_blocks.Struct([element], container_type=tuple) - - self.assertEqual(hash(struct), hash(other)) - - @parameterized.named_parameters( - ( - 'different_elements', - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - building_blocks.Struct( - [ - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ], - container_type=list, - ), - ), - ) - def test_hash_returns_different_value(self, struct, other): - self.assertNotEqual(struct, other) - self.assertNotEqual(hash(struct), hash(other)) - - -class CallTest(parameterized.TestCase): - - def test_eq_returns_true(self): - type_signature = computation_types.TensorType(np.int32) - result = building_blocks.Reference('x', type_signature) - fn = building_blocks.Lambda('x', type_signature, result) - arg = building_blocks.Literal(1, type_signature) - call = building_blocks.Call(fn, arg) - other = building_blocks.Call(fn, arg) - - self.assertIsNot(call, other) - self.assertEqual(call, other) - - @parameterized.named_parameters( - ( - 'different_fn', - building_blocks.Call( - building_blocks.Lambda( - 'x', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'x', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - ), - building_blocks.Call( - building_blocks.Reference( - 'different', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'arg', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_arg', - building_blocks.Call( - building_blocks.Reference( - 'fn', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'arg', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Call( - building_blocks.Reference( - 'fn', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ), - ) - def test_eq_returns_false(self, call, other): - self.assertIsNot(call, other) - self.assertNotEqual(call, other) - - def test_hash_returns_same_value(self): - type_signature = computation_types.TensorType(np.int32) - fn = building_blocks.Reference( - 'fn', computation_types.FunctionType(type_signature, type_signature) - ) - arg = building_blocks.Reference('arg', type_signature) - call = building_blocks.Call(fn, arg) - other = building_blocks.Call(fn, arg) - - self.assertEqual(hash(call), hash(other)) - - @parameterized.named_parameters( - ( - 'different_fn', - building_blocks.Call( - building_blocks.Reference( - 'fn', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'arg', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Call( - building_blocks.Reference( - 'different', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'arg', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_arg', - building_blocks.Call( - building_blocks.Reference( - 'fn', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'arg', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Call( - building_blocks.Reference( - 'fn', - computation_types.FunctionType( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ), - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ), - ) - def test_hash_returns_different_value(self, call, other): - self.assertNotEqual(call, other) - self.assertNotEqual(hash(call), hash(other)) - - -class LambdaTest(parameterized.TestCase): - - def test_eq_returns_true(self): - type_signature = computation_types.TensorType(np.int32) - result = building_blocks.Reference('result', type_signature) - fn = building_blocks.Lambda('parameter', type_signature, result) - other = building_blocks.Lambda('parameter', type_signature, result) - - self.assertIsNot(fn, other) - self.assertEqual(fn, other) - - @parameterized.named_parameters( - ( - 'different_parameter_name', - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Lambda( - 'different', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_parameter_type', - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.float32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_result', - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ), - ) - def test_eq_returns_false(self, fn, other): - self.assertIsNot(fn, other) - self.assertNotEqual(fn, other) - - def test_hash_returns_same_value(self): - type_signature = computation_types.TensorType(np.int32) - result = building_blocks.Reference('result', type_signature) - fn = building_blocks.Lambda('parameter', type_signature, result) - other = building_blocks.Lambda('parameter', type_signature, result) - - self.assertEqual(hash(fn), hash(other)) - - @parameterized.named_parameters( - ( - 'different_parameter_name', - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Lambda( - 'different', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_parameter_type', - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.float32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_result', - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Lambda( - 'parameter', - computation_types.TensorType(np.int32), - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ), - ) - def test_hash_returns_different_value(self, fn, other): - self.assertNotEqual(fn, other) - self.assertNotEqual(hash(fn), hash(other)) - - -class BlockTest(parameterized.TestCase): - - def test_eq_returns_true(self): - type_signature = computation_types.TensorType(np.int32) - local = building_blocks.Reference('local', type_signature) - result = building_blocks.Reference('result', type_signature) - block = building_blocks.Block([('x', local)], result) - other = building_blocks.Block([('x', local)], result) - - self.assertIsNot(block, other) - self.assertEqual(block, other) - - @parameterized.named_parameters( - ( - 'different_locals', - building_blocks.Block( - [( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Block( - [( - 'different', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_result', - building_blocks.Block( - [( - 'x', - building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Block( - [( - 'x', - building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ), - ) - def test_eq_returns_false(self, block, other): - self.assertIsNot(block, other) - self.assertNotEqual(block, other) - - def test_hash_returns_same_value(self): - type_signature = computation_types.TensorType(np.int32) - local = building_blocks.Reference('local', type_signature) - result = building_blocks.Reference('result', type_signature) - block = building_blocks.Block([('x', local)], result) - other = building_blocks.Block([('x', local)], result) - - self.assertEqual(hash(block), hash(other)) - - @parameterized.named_parameters( - ( - 'different_locals', - building_blocks.Block( - [( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Block( - [( - 'different', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - ), - ( - 'different_result', - building_blocks.Block( - [( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'result', computation_types.TensorType(np.int32) - ), - ), - building_blocks.Block( - [( - 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ), - )], - building_blocks.Reference( - 'different', computation_types.TensorType(np.int32) - ), - ), - ), - ) - def test_hash_returns_different_value(self, block, other): - self.assertNotEqual(block, other) - self.assertNotEqual(hash(block), hash(other)) - - -class IntrinsicTest(parameterized.TestCase): - - def test_eq_returns_true(self): - type_signature = computation_types.TensorType(np.int32) - intrinsic = building_blocks.Intrinsic('intrinsic', type_signature) - other = building_blocks.Intrinsic('intrinsic', type_signature) - - self.assertIsNot(intrinsic, other) - self.assertEqual(intrinsic, other) - - @parameterized.named_parameters( - ( - 'different_uri', - building_blocks.Intrinsic( - 'intrinsic', computation_types.TensorType(np.int32) - ), - building_blocks.Intrinsic( - 'different', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_type_signature', - building_blocks.Intrinsic( - 'intrinsic', computation_types.TensorType(np.int32) - ), - building_blocks.Intrinsic( - 'intrinsic', computation_types.TensorType(np.float32) - ), - ), - ) - def test_eq_returns_false(self, intrinsic, other): - self.assertIsNot(intrinsic, other) - self.assertNotEqual(intrinsic, other) - - def test_hash_returns_same_value(self): - type_signature = computation_types.TensorType(np.int32) - intrinsic = building_blocks.Intrinsic('intrinsic', type_signature) - other = building_blocks.Intrinsic('intrinsic', type_signature) - - self.assertEqual(hash(intrinsic), hash(other)) - - @parameterized.named_parameters( - ( - 'different_uri', - building_blocks.Intrinsic( - 'intrinsic', computation_types.TensorType(np.int32) - ), - building_blocks.Intrinsic( - 'different', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_type_signature', - building_blocks.Intrinsic( - 'intrinsic', computation_types.TensorType(np.int32) - ), - building_blocks.Intrinsic( - 'intrinsic', computation_types.TensorType(np.float32) - ), - ), - ) - def test_hash_returns_different_value(self, intrinsic, other): - self.assertNotEqual(intrinsic, other) - self.assertNotEqual(hash(intrinsic), hash(other)) - - -class DataTest(parameterized.TestCase): - - def test_eq_returns_true(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3], np.int32) - ) - type_signature = computation_types.TensorType(np.int32) - data = building_blocks.Data(any_proto, type_signature) - other = building_blocks.Data(any_proto, type_signature) - - self.assertIsNot(data, other) - self.assertEqual(data, other) - - def test_eq_returns_false_different_content(self): - any_proto1 = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3], np.int32) - ) - type_signature = computation_types.TensorType(np.int32) - data = building_blocks.Data(any_proto1, type_signature) - - any_proto2 = building_block_test_utils.create_any_proto_from_array( - np.array([4], np.int32) - ) - other = building_blocks.Data(any_proto2, type_signature) - self.assertIsNot(data, other) - self.assertNotEqual(data, other) - - def test_eq_returns_false_different_type_signatures(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 1, 1], np.int32) - ) - type_signature1 = computation_types.TensorType(np.int32) - type_signature2 = computation_types.TensorType(np.float32) - data = building_blocks.Data(any_proto, type_signature1) - other = building_blocks.Data(any_proto, type_signature2) - - self.assertIsNot(data, other) - self.assertNotEqual(data, other) - - def test_hash_returns_same_value(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3], np.int32) - ) - type_signature = computation_types.TensorType(np.int32) - data = building_blocks.Data(any_proto, type_signature) - other = building_blocks.Data(any_proto, type_signature) - - self.assertEqual(hash(data), hash(other)) - - def test_hash_returns_different_value_for_different_content(self): - any_proto1 = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3], np.int32) - ) - type_signature = computation_types.TensorType(np.int32) - data = building_blocks.Data(any_proto1, type_signature) - - any_proto2 = building_block_test_utils.create_any_proto_from_array( - np.array([4], np.int32) - ) - other = building_blocks.Data(any_proto2, type_signature) - self.assertNotEqual(data, other) - self.assertNotEqual(hash(data), hash(other)) - - def test_hash_returns_different_value_for_different_type_signatures(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 1, 1], np.int32) - ) - type_signature1 = computation_types.TensorType(np.int32) - type_signature2 = computation_types.TensorType(np.float32) - data = building_blocks.Data(any_proto, type_signature1) - other = building_blocks.Data(any_proto, type_signature2) - - self.assertNotEqual(data, other) - self.assertNotEqual(hash(data), hash(other)) - - -class CompiledComputationTest(parameterized.TestCase): - - def test_eq_returns_true(self): - type_spec = computation_types.TensorType(np.int32) - proto = computation_factory.create_lambda_identity(type_spec) - type_signature = computation_types.FunctionType(type_spec, type_spec) - compiled = building_blocks.CompiledComputation( - proto, name='compiled', type_signature=type_signature - ) - other = building_blocks.CompiledComputation( - proto, name='compiled', type_signature=type_signature - ) - - self.assertIsNot(compiled, other) - self.assertEqual(compiled, other) - - @parameterized.named_parameters( - ( - 'different_proto', - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.float32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - ), - ( - 'different_name', - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='different', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - ), - ( - 'different_type_signature', - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType( - np.float32, np.float32 - ), - ), - ), - ) - def test_eq_returns_false(self, compiled, other): - self.assertIsNot(compiled, other) - self.assertNotEqual(compiled, other) - - def test_hash_returns_same_value(self): - type_spec = computation_types.TensorType(np.int32) - proto = computation_factory.create_lambda_identity(type_spec) - type_signature = computation_types.FunctionType(type_spec, type_spec) - compiled = building_blocks.CompiledComputation( - proto, name='compiled', type_signature=type_signature - ) - other = building_blocks.CompiledComputation( - proto, name='compiled', type_signature=type_signature - ) - - self.assertEqual(hash(compiled), hash(other)) - - @parameterized.named_parameters( - ( - 'different_proto', - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.float32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - ), - ( - 'different_name', - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='different', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - ), - ( - 'different_type_signature', - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType(np.int32, np.int32), - ), - building_blocks.CompiledComputation( - computation_factory.create_lambda_identity( - computation_types.TensorType(np.int32) - ), - name='compiled', - type_signature=computation_types.FunctionType( - np.float32, np.float32 - ), - ), - ), - ) - def test_hash_returns_different_value(self, compiled, other): - self.assertNotEqual(compiled, other) - self.assertNotEqual(hash(compiled), hash(other)) - - -class PlacementTest(parameterized.TestCase): - - def test_eq_returns_true(self): - placement = building_blocks.Placement(placements.CLIENTS) - other = building_blocks.Placement(placements.CLIENTS) - - self.assertIsNot(placement, other) - self.assertEqual(placement, other) - - @parameterized.named_parameters( - ( - 'different_literal', - building_blocks.Placement(placements.CLIENTS), - building_blocks.Placement(placements.SERVER), - ), - ) - def test_eq_returns_false(self, placement, other): - self.assertIsNot(placement, other) - self.assertNotEqual(placement, other) - - def test_hash_returns_same_value(self): - placement = building_blocks.Placement(placements.CLIENTS) - other = building_blocks.Placement(placements.CLIENTS) - - self.assertEqual(hash(placement), hash(other)) - - @parameterized.named_parameters( - ( - 'different_literal', - building_blocks.Placement(placements.CLIENTS), - building_blocks.Placement(placements.SERVER), - ), - ) - def test_hash_returns_different_value(self, placement, other): - self.assertNotEqual(placement, other) - self.assertNotEqual(hash(placement), hash(other)) - - -class LiteralTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'bool', - True, - computation_types.TensorType(np.bool_) - ), - ('int8', 1, computation_types.TensorType(np.int8)), - ('int16', 1, computation_types.TensorType(np.int16)), - ('int32', 1, computation_types.TensorType(np.int32)), - ('int64', 1, computation_types.TensorType(np.int64)), - ('uint8', 1, computation_types.TensorType(np.uint8)), - ('uint16', 1, computation_types.TensorType(np.uint16)), - ('uint32', 1, computation_types.TensorType(np.uint32)), - ('uint64', 1, computation_types.TensorType(np.uint64)), - ('float16', 1.0, computation_types.TensorType(np.float16)), - ('float32', 1.0, computation_types.TensorType(np.float32)), - ('float64', 1.0, computation_types.TensorType(np.float64)), - ('complex64', (1.0 + 1.0j), computation_types.TensorType(np.complex64)), - ('complex128', (1.0 + 1.0j), computation_types.TensorType(np.complex128)), - ( - 'bfloat16', - ml_dtypes.bfloat16(1.0), - computation_types.TensorType(ml_dtypes.bfloat16), - ), - ('str', 'a', computation_types.TensorType(np.str_)), - ('bytes', b'a', computation_types.TensorType(np.str_)), - ('generic_int', np.int32(1), computation_types.TensorType(np.int32)), - ( - 'generic_float', - np.float32(1.0), - computation_types.TensorType(np.float32), - ), - ( - 'generic_complex', - np.complex64(1.0 + 1.0j), - computation_types.TensorType(np.complex64), - ), - ( - 'generic_bool', - np.bool_(True), computation_types.TensorType(np.bool_) - ), - ('generic_str', np.str_('a'), computation_types.TensorType(np.str_)), - ('generic_bytes', np.bytes_(b'a'), computation_types.TensorType(np.str_)), - ( - 'array_bool', - np.array([True, False], np.bool_), - computation_types.TensorType(np.bool_, shape=[2]), - ), - ( - 'array_int', - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - ( - 'array_float', - np.array([1.0, 2.0, 3.0], np.float32), - computation_types.TensorType(np.float32, shape=[3]), - ), - ( - 'array_complex', - np.array([(1.0 + 1.0j), (2.0 + 1.0j), (3.0 + 1.0j)], np.complex64), - computation_types.TensorType(np.complex64, shape=[3]), - ), - ( - 'array_str', - np.array(['a', 'b', 'c'], np.str_), - computation_types.TensorType(np.str_, shape=[3]), - ), - ( - 'array_bytes', - np.array([b'a', b'b', b'c'], np.bytes_), - computation_types.TensorType(np.bytes_, shape=[3]), - ), - ( - 'array_no_dimensions', - np.array([], np.int32), - computation_types.TensorType(np.int32, shape=[0]), - ), - ( - 'array_multiple_dimensions', - np.array([[1, 2, 3], [4, 5, 6]], np.int32), - computation_types.TensorType(np.int32, shape=[2, 3]), - ), - ) - def test_init_does_not_raise_value_error(self, value, type_signature): - try: - building_blocks.Literal(value, type_signature) - except ValueError as e: - self.fail('Raised `ValueError` unexpectedly: %s', e) - - @parameterized.named_parameters( - ('str', 'a', computation_types.TensorType(np.str_), b'a'), - ( - 'generic_str', - np.str_('a'), - computation_types.TensorType(np.str_), - b'a', - ), - ( - 'array_str', - np.array(['a', 'b', 'c'], np.str_), - computation_types.TensorType(np.str_, shape=[3]), - np.array([b'a', b'b', b'c'], np.bytes_), - ), - ) - def test_init_normalizes_value_str( - self, value, type_signature, expected_value - ): - literal = building_blocks.Literal(value, type_signature) - - tree.assert_same_structure(literal.value, expected_value) - actual_value = _to_python(literal.value) - expected_value = _to_python(expected_value) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'scalar_and_incompatible_dtype_kind', - 1, - computation_types.TensorType(np.float32), - ), - ( - 'scalar_and_incompatible_dtype_size', - np.iinfo(np.int64).max, - computation_types.TensorType(np.int32), - ), - ( - 'scalar_and_incompatible_shape', - 1, - computation_types.TensorType(np.int32, shape=[2, 3]), - ), - ( - 'generic_and_incompatible_dtype_kind', - np.int32(1), - computation_types.TensorType(np.float32), - ), - ( - 'generic_and_incompatible_dtype_size', - np.int64(np.iinfo(np.int64).max), - computation_types.TensorType(np.int32), - ), - ( - 'generic_and_incompatible_shape', - np.int32(1), - computation_types.TensorType(np.int32, shape=[2, 3]), - ), - ( - 'array_and_incompatible_dtype_kind', - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.float32, shape=[3]), - ), - ( - 'array_and_incompatible_dtype_size', - np.array([np.iinfo(np.int64).max] * 3, np.int64), - computation_types.TensorType(np.int32, shape=[3]), - ), - ( - 'array_and_incompatible_shape', - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[2, 3]), - ), - ) - def test_init_raises_value_error(self, value, type_signature): - with self.assertRaises(ValueError): - building_blocks.Literal(value, type_signature) - - def test_from_proto_returns_value(self): - proto = computation_pb2.Computation( - type=computation_pb2.Type( - tensor=computation_pb2.TensorType( - dtype=data_type_pb2.DataType.DT_INT32, dims=[] - ) - ), - literal=computation_pb2.Literal( - value=array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ) - ), - ) - - actual_value = building_blocks.Literal.from_proto(proto) - - expected_value = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ) - self.assertEqual(actual_value, expected_value) - - def test_from_proto_raises_value_error_with_wrong_type(self): - proto = computation_pb2.Computation( - type=computation_pb2.Type( - federated=computation_pb2.FederatedType( - placement=computation_pb2.PlacementSpec( - value=computation_pb2.Placement(uri=placements.CLIENTS.uri) - ), - member=computation_pb2.Type( - tensor=computation_pb2.TensorType( - dtype=data_type_pb2.DataType.DT_INT32, dims=[] - ), - ), - ), - ), - literal=computation_pb2.Literal( - value=array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ) - ), - ) - - with self.assertRaises(ValueError): - building_blocks.Literal.from_proto(proto) - - def test_from_proto_raises_value_error_with_wrong_computation(self): - proto = computation_pb2.Computation( - type=computation_pb2.Type( - tensor=computation_pb2.TensorType( - dtype=data_type_pb2.DataType.DT_INT32, dims=[] - ) - ), - data=computation_pb2.Data(), - ) - - with self.assertRaises(ValueError): - building_blocks.Literal.from_proto(proto) - - def test_to_proto_returns_value(self): - literal = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - - actual_value = literal._proto() - - expected_value = computation_pb2.Computation( - type=computation_pb2.Type( - tensor=computation_pb2.TensorType( - dtype=data_type_pb2.DataType.DT_INT32, dims=[] - ) - ), - literal=computation_pb2.Literal( - value=array_pb2.Array( - dtype=data_type_pb2.DataType.DT_INT32, - shape=array_pb2.ArrayShape(dim=[]), - int32_list=array_pb2.Array.IntList(value=[1]), - ) - ), - ) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ( - 'bool', - building_blocks.Literal( - True, - computation_types.TensorType(np.bool_), - ), - ), - ( - 'int', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - ), - ( - 'float', - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float32) - ), - ), - ( - 'complex', - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex64) - ), - ), - ( - 'str', - building_blocks.Literal('a', computation_types.TensorType(np.str_)), - ), - ( - 'bytes', - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - ), - ( - 'generic', - building_blocks.Literal( - np.int32(1), computation_types.TensorType(np.int32) - ), - ), - ( - 'array', - building_blocks.Literal( - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - ), - ) - def test_children_empty(self, literal): - actual_children = list(literal.children()) - self.assertEmpty(actual_children) - - @parameterized.named_parameters( - ( - 'bool', - building_blocks.Literal( - True, - computation_types.TensorType(np.bool_) - ), - building_blocks.Literal( - True, - computation_types.TensorType(np.bool_) - ), - ), - ( - 'int8', - building_blocks.Literal(1, computation_types.TensorType(np.int8)), - building_blocks.Literal(1, computation_types.TensorType(np.int8)), - ), - ( - 'int16', - building_blocks.Literal(1, computation_types.TensorType(np.int16)), - building_blocks.Literal(1, computation_types.TensorType(np.int16)), - ), - ( - 'int32', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - ), - ( - 'int64', - building_blocks.Literal(1, computation_types.TensorType(np.int64)), - building_blocks.Literal(1, computation_types.TensorType(np.int64)), - ), - ( - 'uint8', - building_blocks.Literal(1, computation_types.TensorType(np.uint8)), - building_blocks.Literal(1, computation_types.TensorType(np.uint8)), - ), - ( - 'uint16', - building_blocks.Literal(1, computation_types.TensorType(np.uint16)), - building_blocks.Literal(1, computation_types.TensorType(np.uint16)), - ), - ( - 'uint32', - building_blocks.Literal(1, computation_types.TensorType(np.uint32)), - building_blocks.Literal(1, computation_types.TensorType(np.uint32)), - ), - ( - 'uint64', - building_blocks.Literal(1, computation_types.TensorType(np.uint64)), - building_blocks.Literal(1, computation_types.TensorType(np.uint64)), - ), - ( - 'float16', - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float16) - ), - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float16) - ), - ), - ( - 'float32', - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float32) - ), - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float32) - ), - ), - ( - 'float64', - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float64) - ), - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float64) - ), - ), - ( - 'complex64', - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex64) - ), - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex64) - ), - ), - ( - 'complex128', - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex128) - ), - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex128) - ), - ), - ( - 'str', - building_blocks.Literal('a', computation_types.TensorType(np.str_)), - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - ), - ( - 'bytes', - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - ), - ( - 'generic', - building_blocks.Literal( - np.int32(1), computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - np.int32(1), computation_types.TensorType(np.int32) - ), - ), - ( - 'array', - building_blocks.Literal( - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - building_blocks.Literal( - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - ), - ) - def test_eq_returns_true(self, literal, other): - self.assertIsNot(literal, other) - self.assertEqual(literal, other) - - @parameterized.named_parameters( - ( - 'different_value', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - building_blocks.Literal(2, computation_types.TensorType(np.int32)), - ), - ( - 'different_type_signature', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - building_blocks.Literal(1, computation_types.TensorType(np.int64)), - ), - ) - def test_eq_returns_false(self, literal, other): - self.assertIsNot(literal, other) - self.assertNotEqual(literal, other) - - @parameterized.named_parameters( - ( - 'bool', - building_blocks.Literal( - True, computation_types.TensorType(np.bool_) - ), - building_blocks.Literal( - True, computation_types.TensorType(np.bool_) - ), - ), - ( - 'int', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - ), - ( - 'uint', - building_blocks.Literal(1, computation_types.TensorType(np.uint32)), - building_blocks.Literal(1, computation_types.TensorType(np.uint32)), - ), - ( - 'float', - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float32) - ), - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float32) - ), - ), - ( - 'complex', - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex64) - ), - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex64) - ), - ), - ( - 'str', - building_blocks.Literal('a', computation_types.TensorType(np.str_)), - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - ), - ( - 'bytes', - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - ), - ( - 'generic', - building_blocks.Literal( - np.int32(1), computation_types.TensorType(np.int32) - ), - building_blocks.Literal( - np.int32(1), computation_types.TensorType(np.int32) - ), - ), - ( - 'array', - building_blocks.Literal( - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - building_blocks.Literal( - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - ), - ) - def test_hash_returns_same_value(self, literal, other): - self.assertEqual(hash(literal), hash(other)) - - @parameterized.named_parameters( - ( - 'different_value', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - building_blocks.Literal(2, computation_types.TensorType(np.int32)), - ), - ( - 'different_type_signature', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - building_blocks.Literal(1, computation_types.TensorType(np.int64)), - ), - ) - def test_hash_returns_different_value(self, literal, other): - self.assertNotEqual(literal, other) - self.assertNotEqual(hash(literal), hash(other)) - - @parameterized.named_parameters( - ( - 'bool', - building_blocks.Literal( - True, computation_types.TensorType(np.bool_) - ), - 'Literal(True, TensorType(np.bool_))', - ), - ( - 'int', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - 'Literal(1, TensorType(np.int32))', - ), - ( - 'float', - building_blocks.Literal( - 1.0, computation_types.TensorType(np.float32) - ), - 'Literal(1.0, TensorType(np.float32))', - ), - ( - 'complex', - building_blocks.Literal( - (1.0 + 1.0j), computation_types.TensorType(np.complex64) - ), - 'Literal((1+1j), TensorType(np.complex64))', - ), - ( - 'str', - building_blocks.Literal('a', computation_types.TensorType(np.str_)), - "Literal(b'a', TensorType(np.str_))", - ), - ( - 'bytes', - building_blocks.Literal(b'a', computation_types.TensorType(np.str_)), - "Literal(b'a', TensorType(np.str_))", - ), - ( - 'generic', - building_blocks.Literal( - np.int32(1), computation_types.TensorType(np.int32) - ), - 'Literal(1, TensorType(np.int32))', - ), - ( - 'array', - building_blocks.Literal( - np.array([1, 2, 3], np.int32), - computation_types.TensorType(np.int32, shape=[3]), - ), - ( - 'Literal(np.array([1, 2, 3], dtype=np.int32),' - ' TensorType(np.int32, (3,)))' - ), - ), - ) - def test_repr(self, literal, expected_repr): - self.assertEqual(repr(literal), expected_repr) - - -class RepresentationTest(absltest.TestCase): - - def test_returns_string_for_block(self): - data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ref = building_blocks.Reference('c', np.int32) - comp = building_blocks.Block((('a', data), ('b', data)), ref) - - self.assertEqual(comp.compact_representation(), '(let a=1,b=1 in c)') - # pyformat: disable - self.assertEqual( - comp.formatted_representation(), - '(let\n' - ' a=1,\n' - ' b=1\n' - ' in c)' - ) - self.assertEqual( - comp.structural_representation(), - ' Block\n' - ' / \\\n' - '[a=Lit(1), b=Lit(1)] Ref(c)' - ) - # pyformat: enable - - def test_returns_string_for_call_with_arg(self): - fn_type = computation_types.FunctionType(np.int32, np.int32) - fn = building_blocks.Reference('a', fn_type) - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - comp = building_blocks.Call(fn, arg) - - self.assertEqual(comp.compact_representation(), 'a(1)') - self.assertEqual(comp.formatted_representation(), 'a(1)') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - ' Call\n' - ' / \\\n' - 'Ref(a) Lit(1)' - ) - # pyformat: enable - - def test_returns_string_for_call_with_no_arg(self): - fn_type = computation_types.FunctionType(None, np.int32) - fn = building_blocks.Reference('a', fn_type) - comp = building_blocks.Call(fn) - - self.assertEqual(comp.compact_representation(), 'a()') - self.assertEqual(comp.formatted_representation(), 'a()') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - ' Call\n' - ' /\n' - 'Ref(a)' - ) - # pyformat: enable - - def test_returns_string_for_compiled_computation(self): - tensor_type = computation_types.TensorType(np.int32) - proto = computation_factory.create_lambda_identity(tensor_type) - comp = building_blocks.CompiledComputation( - proto, name='a', type_signature=tensor_type - ) - - self.assertEqual(comp.compact_representation(), 'comp#a') - self.assertEqual(comp.formatted_representation(), 'comp#a') - self.assertEqual(comp.structural_representation(), 'Compiled(a)') - - def test_returns_string_for_data(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3], np.int32) - ) - comp = building_blocks.Data(any_proto, np.int32) - - expected = str(id(any_proto)) - self.assertEqual(comp.compact_representation(), expected) - self.assertEqual(comp.formatted_representation(), expected) - self.assertEqual(comp.structural_representation(), f'Data({expected})') - - def test_returns_string_for_intrinsic(self): - comp_type = computation_types.TensorType(np.int32) - comp = building_blocks.Intrinsic('intrinsic', comp_type) - - self.assertEqual(comp.compact_representation(), 'intrinsic') - self.assertEqual(comp.formatted_representation(), 'intrinsic') - self.assertEqual(comp.structural_representation(), 'intrinsic') - - def test_returns_string_for_lambda(self): - ref = building_blocks.Reference('a', np.int32) - comp = building_blocks.Lambda(ref.name, ref.type_signature, ref) - - self.assertEqual(comp.compact_representation(), '(a -> a)') - self.assertEqual(comp.formatted_representation(), '(a -> a)') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - 'Lambda(a)\n' - '|\n' - 'Ref(a)' - ) - # pyformat: enable - - def test_returns_string_for_placement(self): - comp = building_blocks.Placement(placements.CLIENTS) - - self.assertEqual(comp.compact_representation(), 'CLIENTS') - self.assertEqual(comp.formatted_representation(), 'CLIENTS') - self.assertEqual(comp.structural_representation(), 'Placement') - - def test_returns_string_for_reference(self): - comp = building_blocks.Reference('a', np.int32) - - self.assertEqual(comp.compact_representation(), 'a') - self.assertEqual(comp.formatted_representation(), 'a') - self.assertEqual(comp.structural_representation(), 'Ref(a)') - - def test_returns_string_for_selection_with_name(self): - ref = building_blocks.Reference('a', (('b', np.int32), ('c', np.float64))) - comp = building_blocks.Selection(ref, name='b') - - self.assertEqual(comp.compact_representation(), 'a.b') - self.assertEqual(comp.formatted_representation(), 'a.b') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - 'Sel(b)\n' - '|\n' - 'Ref(a)' - ) - # pyformat: enable - - def test_returns_string_for_selection_with_index(self): - ref = building_blocks.Reference('a', (('b', np.int32), ('c', np.float64))) - comp = building_blocks.Selection(ref, index=0) - - self.assertEqual(comp.compact_representation(), 'a[0]') - self.assertEqual(comp.formatted_representation(), 'a[0]') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - 'Sel(0)\n' - '|\n' - 'Ref(a)' - ) - # pyformat: enable - - def test_returns_string_for_struct_with_names(self): - literally = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) - ) - comp = building_blocks.Struct([('a', literally), ('b', literally)]) - - self.assertEqual(comp.compact_representation(), '') - # pyformat: disable - self.assertEqual( - comp.formatted_representation(), - '<\n' - ' a=2,\n' - ' b=2\n' - '>' - ) - self.assertEqual( - comp.structural_representation(), - 'Struct\n' - '|\n' - '[a=Lit(2), b=Lit(2)]' - ) - # pyformat: enable - - def test_returns_string_for_struct_with_no_names(self): - data = building_blocks.Literal(3, computation_types.TensorType(np.int32)) - comp = building_blocks.Struct([data, data]) - - self.assertEqual(comp.compact_representation(), '<3,3>') - # pyformat: disable - self.assertEqual( - comp.formatted_representation(), - '<\n' - ' 3,\n' - ' 3\n' - '>' - ) - self.assertEqual( - comp.structural_representation(), - 'Struct\n' - '|\n' - '[Lit(3), Lit(3)]' - ) - # pyformat: enable - - def test_returns_string_for_struct_with_no_elements(self): - comp = building_blocks.Struct([]) - - self.assertEqual(comp.compact_representation(), '<>') - self.assertEqual(comp.formatted_representation(), '<>') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - 'Struct\n' - '|\n' - '[]' - ) - # pyformat: enable - - def test_returns_string_for_federated_aggregate(self): - comp = building_block_test_utils.create_whimsy_called_federated_aggregate( - accumulate_parameter_name='a', - merge_parameter_name='b', - report_parameter_name='c', - ) - - self.assertEqual( - comp.compact_representation(), - 'federated_aggregate( 1),(b -> 1)' - ',(c -> 1)>)', - ) - # pyformat: disable - self.assertEqual( - comp.formatted_representation(), - 'federated_aggregate(<\n' - ' federated_value_at_clients(1),\n' - ' 1,\n' - ' (a -> 1),\n' - ' (b -> 1),\n' - ' (c -> 1)\n' - '>)' - ) - self.assertEqual( - comp.structural_representation(), - ' Call\n' - ' / \\\n' - ' federated_aggregate Struct\n' - ' |\n' - ' [Call, Lit(1), Lambda(a), Lambda(b), ' - 'Lambda(c)]\n' - ' / \\ | | |\n' - 'federated_value_at_clients Lit(1) Lit(1) Lit(1) Lit(1)' - ) - # pyformat: enable - - def test_returns_string_for_federated_map(self): - comp = building_block_test_utils.create_whimsy_called_federated_map( - parameter_name='a' - ) - - self.assertEqual( - comp.compact_representation(), - 'federated_map(<(a -> a),federated_value_at_clients(1)>)', - ) - # pyformat: disable - self.assertEqual( - comp.formatted_representation(), - 'federated_map(<\n' - ' (a -> a),\n' - ' federated_value_at_clients(1)\n' - '>)' - ) - self.assertEqual( - comp.structural_representation(), - ' Call\n' - ' / \\\n' - 'federated_map Struct\n' - ' |\n' - ' [Lambda(a), Call]\n' - ' | / \\\n' - ' Ref(a) federated_value_at_clients Lit(1)' - ) - # pyformat: enable - - def test_returns_string_for_comp_with_left_overhang(self): - fn_1_type = computation_types.FunctionType(np.int32, np.int32) - fn_1 = building_blocks.Reference('a', fn_1_type) - fn_2_type = computation_types.FunctionType(None, np.int32) - fn_2 = building_blocks.Reference('bbbbbbbbbb', fn_2_type) - arg = building_blocks.Call(fn_2) - comp = building_blocks.Call(fn_1, arg) - - self.assertEqual(comp.compact_representation(), 'a(bbbbbbbbbb())') - self.assertEqual(comp.formatted_representation(), 'a(bbbbbbbbbb())') - # pyformat: disable - self.assertEqual( - comp.structural_representation(), - ' Call\n' - ' / \\\n' - ' Ref(a) Call\n' - ' /\n' - 'Ref(bbbbbbbbbb)' - ) - # pyformat: enable - - def test_returns_string_for_comp_with_right_overhang(self): - ref = building_blocks.Reference('a', np.int32) - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - tup = building_blocks.Struct([ref, lit, lit, lit, lit]) - sel = building_blocks.Selection(tup, index=0) - fn = building_blocks.Lambda(ref.name, ref.type_signature, sel) - comp = building_blocks.Call(fn, lit) - - self.assertEqual(comp.compact_representation(), '(a -> [0])(1)') - # pyformat: disable - self.assertEqual( - comp.formatted_representation(), - '(a -> <\n' - ' a,\n' - ' 1,\n' - ' 1,\n' - ' 1,\n' - ' 1\n' - '>[0])(1)' - ) - self.assertEqual( - comp.structural_representation(), - ' Call\n' - ' / \\\n' - 'Lambda(a) Lit(1)\n' - '|\n' - 'Sel(0)\n' - '|\n' - 'Struct\n' - '|\n' - '[Ref(a), Lit(1), Lit(1), Lit(1), Lit(1)]' - ) - # pyformat: enable - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/compiler_test_utils.py b/tensorflow_federated/python/core/impl/compiler/compiler_test_utils.py index b710019093..282da2ff51 100644 --- a/tensorflow_federated/python/core/impl/compiler/compiler_test_utils.py +++ b/tensorflow_federated/python/core/impl/compiler/compiler_test_utils.py @@ -15,41 +15,43 @@ import collections +import federated_language + from tensorflow_federated.python.common_libs import golden from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import transformation_utils # Name the compiled computations to avoid the issue that the TF graphs being # generated are different at HEAD vs in OSS, resulting in different hash values # for the computation name which fail to compare. def _name_compiled_computations( - tree: building_blocks.ComputationBuildingBlock, -) -> building_blocks.ComputationBuildingBlock: + tree: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.ComputationBuildingBlock: """Name the compiled computations.""" counter = 1 def _transform(building_block): nonlocal counter - if isinstance(building_block, building_blocks.CompiledComputation): + if isinstance( + building_block, federated_language.framework.CompiledComputation + ): new_name = str(counter) counter += 1 return ( - building_blocks.CompiledComputation( + federated_language.framework.CompiledComputation( proto=building_block.proto, name=new_name ), True, ) return building_block, False - return transformation_utils.transform_postorder(tree, _transform)[0] + return federated_language.framework.transform_postorder(tree, _transform)[0] def check_computations( filename: str, computations: collections.OrderedDict[ - str, building_blocks.ComputationBuildingBlock + str, federated_language.framework.ComputationBuildingBlock ], ) -> None: """Check the AST of computations matches the contents of the golden file. @@ -57,7 +59,7 @@ def check_computations( Args: filename: String filename of the golden file. computations: An `collections.OrderedDict` of computation names to - `building_blocks.ComputationBuildingBlock`. + `federated_language.framework.ComputationBuildingBlock`. Raises: TypeError: If any argument type mismatches. @@ -67,7 +69,7 @@ def check_computations( values = [] for name, computation in computations.items(): py_typecheck.check_type( - computation, building_blocks.ComputationBuildingBlock, name + computation, federated_language.framework.ComputationBuildingBlock, name ) computation_ast = _name_compiled_computations(computation) values.append( diff --git a/tensorflow_federated/python/core/impl/compiler/computation_factory.py b/tensorflow_federated/python/core/impl/compiler/computation_factory.py deleted file mode 100644 index 4c51d0be63..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/computation_factory.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of construction functions for computation structures.""" - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_factory -from tensorflow_federated.python.core.impl.types import type_serialization - - -def create_lambda_empty_struct() -> pb.Computation: - """Returns a lambda computation returning an empty struct. - - Has the type signature: - - ( -> <>) - - Returns: - An instance of `pb.Computation`. - """ - result_type = computation_types.StructType([]) - type_signature = computation_types.FunctionType(None, result_type) - result = pb.Computation( - type=type_serialization.serialize_type(result_type), - struct=pb.Struct(element=[]), - ) - fn = pb.Lambda(parameter_name=None, result=result) - # We are unpacking the lambda argument here because `lambda` is a reserved - # keyword in Python, but it is also the name of the parameter for a - # `pb.Computation`. - # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts - return pb.Computation( - type=type_serialization.serialize_type(type_signature), **{'lambda': fn} - ) # pytype: disable=wrong-keyword-args - - -def create_lambda_identity(type_spec: computation_types.Type) -> pb.Computation: - """Returns a lambda computation representing an identity function. - - Has the type signature: - - (T -> T) - - Args: - type_spec: A `computation_types.Type`. - - Returns: - An instance of `pb.Computation`. - """ - type_signature = type_factory.unary_op(type_spec) - result = pb.Computation( - type=type_serialization.serialize_type(type_spec), - reference=pb.Reference(name='a'), - ) - fn = pb.Lambda(parameter_name='a', result=result) - # We are unpacking the lambda argument here because `lambda` is a reserved - # keyword in Python, but it is also the name of the parameter for a - # `pb.Computation`. - # https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts - return pb.Computation( - type=type_serialization.serialize_type(type_signature), **{'lambda': fn} - ) # pytype: disable=wrong-keyword-args diff --git a/tensorflow_federated/python/core/impl/compiler/computation_factory_test.py b/tensorflow_federated/python/core/impl/compiler/computation_factory_test.py deleted file mode 100644 index a09779d6f0..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/computation_factory_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_factory -from tensorflow_federated.python.core.impl.types import type_serialization - - -class CreateLambdaEmptyTupleTest(absltest.TestCase): - - def test_returns_computation(self): - proto = computation_factory.create_lambda_empty_struct() - - self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = computation_types.FunctionType( - None, computation_types.StructType(()) - ) - self.assertEqual(actual_type, expected_type) - - -class CreateLambdaIdentityTest(absltest.TestCase): - - def test_returns_computation_int(self): - type_signature = computation_types.TensorType(np.int32) - - proto = computation_factory.create_lambda_identity(type_signature) - - self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = type_factory.unary_op(type_signature) - self.assertEqual(actual_type, expected_type) - - def test_returns_computation_tuple_unnamed(self): - type_signature = computation_types.StructType([np.int32, np.float32]) - - proto = computation_factory.create_lambda_identity(type_signature) - - self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = type_factory.unary_op(type_signature) - self.assertEqual(actual_type, expected_type) - - def test_returns_computation_tuple_named(self): - type_signature = computation_types.StructType( - [('a', np.int32), ('b', np.float32)] - ) - - proto = computation_factory.create_lambda_identity(type_signature) - - self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = type_factory.unary_op(type_signature) - self.assertEqual(actual_type, expected_type) - - def test_returns_computation_sequence(self): - type_signature = computation_types.SequenceType(np.int32) - - proto = computation_factory.create_lambda_identity(type_signature) - - self.assertIsInstance(proto, pb.Computation) - actual_type = type_serialization.deserialize_type(proto.type) - expected_type = type_factory.unary_op(type_signature) - self.assertEqual(actual_type, expected_type) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py b/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py deleted file mode 100644 index a9eeb65055..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/intrinsic_defs.py +++ /dev/null @@ -1,742 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Definitions of all intrinsic for use within the system.""" - -import enum -from typing import Optional - -import numpy as np - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_factory - -_intrinsic_registry = {} - - -@enum.unique -class BroadcastKind(enum.Enum): - DEFAULT = 1 - SECURE = 2 - - -@enum.unique -class AggregationKind(enum.Enum): - DEFAULT = 1 - SECURE = 2 - - -class IntrinsicDef: - """Represents the definition of an intrinsic. - - This class represents the ultimate source of ground truth about what kinds of - intrinsics exist and what their types are. To be consuled by all other code - that deals with intrinsics. - """ - - def __init__( - self, - name: str, - uri: str, - type_signature: computation_types.Type, - aggregation_kind: Optional[AggregationKind] = None, - broadcast_kind: Optional[BroadcastKind] = None, - ): - """Constructs a definition of an intrinsic. - - Args: - name: The short human-friendly name of this intrinsic. - uri: The URI of this intrinsic. - type_signature: The type of the intrinsic. - aggregation_kind: Optional kind of aggregation performed by calls. - broadcast_kind: Optional kind of broadcast performed by calls. - """ - py_typecheck.check_type(name, str) - py_typecheck.check_type(uri, str) - py_typecheck.check_type(type_signature, computation_types.Type) - self._name = str(name) - self._uri = str(uri) - self._type_signature = type_signature - self._aggregation_kind = aggregation_kind - self._broadcast_kind = broadcast_kind - _intrinsic_registry[str(uri)] = self - - # TODO: b/113112885 - Add support for an optional type checking function that - # can verify whether this intrinsic is applicable to given kinds of arguments, - # e.g., to allow sum-like functions to be applied only to arguments that are - # composed of tensors as leaves of a possible nested structure. - - @property - def name(self): - return self._name - - @property - def uri(self): - return self._uri - - @property - def type_signature(self): - return self._type_signature - - @property - def aggregation_kind(self) -> Optional[AggregationKind]: - return self._aggregation_kind - - @property - def broadcast_kind(self) -> Optional[BroadcastKind]: - return self._broadcast_kind - - def __str__(self): - return self._name - - def __repr__(self): - return "IntrinsicDef('{}')".format(self._uri) - - -# TODO: b/113112885 - Perhaps add a way for these to get auto-registered to -# enable things like lookup by URI, etc., similarly to how it's handled in the -# placements.py. - -# TODO: b/113112885 - Define the generic equivalents of all operators below, -# i.e., intrinsics that support arbitrary placements, to allow the federated -# operators to be decomposed into expressions that might involve one or more -# layers of intermediate aggregators. The type signatures of these generic -# intrinsics are tentatively defined as follows: -# -# - Place an unplaced value: -# -# generic_place: -> T@p -# -# - Compute an aggregate using the 4-part aggregation operator interface: -# -# generic_aggregate: <{T}@p,U,(->U),(->U),(U->R),q> -> R@q -# -# - Compute an unweighted average: -# -# generic_average: <{T}@p,q> -> T@q -# -# - Broadcast an item: -# -# generic_broadcast: -> T@q -# -# - Materialize a federated value as a set of sequences at another placement, -# with the participants at 'q' collecting from disjoint subsets of 'p' that -# jointly cover all of 'p'. -# -# generic_partial_collect: <{T}@p,q> -> {T*}@q -# -# - Materialize a federated value as a single sequence: -# -# generic_collect: <{T}@p,q> -> T*@q -# -# - Pointwise mapping of constituents of a federated value: -# -# generic_map: <(T->U),{T}@p> -> {U}@p -# -# - Pointwise mapping of all-equal constituents of a federated value: -# -# generic_apply: <(T->U),T@p> -> U@p -# -# - Perform one-stage set reduction of a federated value with a given operator, -# with the participants at 'q' reducing over disjoint subsets of 'p' that -# jointly cover all of 'p'. -# -# generic_partial_reduce: <{T}@p,U,(->U),q> -> {U}@q -# -# - Perform complete set reduction of a federated value with a given operator: -# -# generic_reduce: <{T}@p,U,(->U),q> -> U@q -# -# - Select and agree on a single member consistuent of a federated value (this -# is technically need to project {T}@SERVER to T@SERVER in a manner that is -# formally consistent; a technicality that we do not expect to surface in the -# user API). -# -# generic_only: {T}@p -> T@p -# -# - Compute a partial sum of a value (for values with numeric constituents): -# -# generic_partial_sum: <{T}@p,q> -> {T}@q -# -# - Compute a simple sum of a value (for values with numeric constituents): -# -# generic_sum: <{T}@p,q> -> T@q -# -# - Compute an average weighted by a numeric non-complex scalar: -# -# generic_weighted_average: <{T}@p,{U}@p,q> -> T@q -# -# - Transform a pair of federated values into a federated pair (technicality we -# expect to bury through implicit conversions, TBD). -# -# generic_zip: <{T}@p,{U}@p> -> {}@p - -# Computes an aggregate of client items (the first, {T}@CLIENTS-typed parameter) -# using a multi-stage process in which client items are first partially -# aggregated at an intermediate layer, then the partial aggregates are further -# combined, and finally projected into the result. This multi-stage process is -# parameterized by a four-part aggregation interface that consists of the -# following: -# a) The 'zero' in the algebra used at the initial stage (partial aggregation), -# This is the second, U-typed parameter. -# b) The operator that accumulates T-typed client items into the U-typed partial -# aggregates. This is the third, (->U)-typed parameter. -# c) The operator that combines pairs of U-typed partial aggregates. This is the -# fourth, (->U)-typed parameter. -# d) The operator that projects the top-level aggregate into the final result. -# This is the fifth, (U->R)-typed parameter. -# -# Conceptually, given a new literal INTERMEDIATE_AGGREGATORS in a single-layer -# aggregation architecture, one could define this intrinsic in terms of generic -# intrinsics defined above, as follows. -# -# @federated_computation -# def federated_aggregate(x, zero, accumulate, merge, report): -# a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS) -# b = generic_reduce(a, zero, merge, SERVER) -# c = generic_map(report, b) -# return c -# -# Actual implementations might vary. -# -# Type signature: <{T}@CLIENTS,U,(->U),(->U),(U->R)> -> R@SERVER -FEDERATED_AGGREGATE = IntrinsicDef( - 'FEDERATED_AGGREGATE', - 'federated_aggregate', - computation_types.FunctionType( - parameter=[ - computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - computation_types.AbstractType('U'), - type_factory.reduction_op( - computation_types.AbstractType('U'), - computation_types.AbstractType('T'), - ), - type_factory.binary_op(computation_types.AbstractType('U')), - computation_types.FunctionType( - computation_types.AbstractType('U'), - computation_types.AbstractType('R'), - ), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('R'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.DEFAULT, -) - -# Applies a given function to a value on the server. -# -# Type signature: <(T->U),T@SERVER> -> U@SERVER -FEDERATED_APPLY = IntrinsicDef( - 'FEDERATED_APPLY', - 'federated_apply', - computation_types.FunctionType( - parameter=[ - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ), - computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('U'), placements.SERVER - ), - ), -) - -# Broadcasts a server item to all clients. -# -# Type signature: T@SERVER -> T@CLIENTS -FEDERATED_BROADCAST = IntrinsicDef( - 'FEDERATED_BROADCAST', - 'federated_broadcast', - computation_types.FunctionType( - parameter=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), - placements.CLIENTS, - all_equal=True, - ), - ), - broadcast_kind=BroadcastKind.DEFAULT, -) - -# Evaluates a function at the clients. -# -# Type signature: (() -> T) -> {T}@CLIENTS -FEDERATED_EVAL_AT_CLIENTS = IntrinsicDef( - 'FEDERATED_EVAL_AT_CLIENTS', - 'federated_eval_at_clients', - computation_types.FunctionType( - parameter=computation_types.FunctionType( - None, computation_types.AbstractType('T') - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - ), -) - -# Evaluates a function at the server. -# -# Type signature: (() -> T) -> T@SERVER -FEDERATED_EVAL_AT_SERVER = IntrinsicDef( - 'FEDERATED_EVAL_AT_SERVER', - 'federated_eval_at_server', - computation_types.FunctionType( - parameter=computation_types.FunctionType( - None, computation_types.AbstractType('T') - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), -) - -# Maps member constituents of a client value pointwise using a given mapping -# function that operates independently on each client. -# -# Type signature: <(T->U),{T}@CLIENTS> -> {U}@CLIENTS -FEDERATED_MAP = IntrinsicDef( - 'FEDERATED_MAP', - 'federated_map', - computation_types.FunctionType( - parameter=[ - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ), - computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('U'), placements.CLIENTS - ), - ), -) - -# Maps member constituents of a client all equal value pointwise using a given -# mapping function that operates independently on each client, as a result of -# this independence, the value is only garunteed to be all equal if the function -# is deterministic. -# -# Type signature: <(T->U),T@CLIENTS> -> U@CLIENTS -FEDERATED_MAP_ALL_EQUAL = IntrinsicDef( - 'FEDERATED_MAP_ALL_EQUAL', - 'federated_map_all_equal', - computation_types.FunctionType( - parameter=[ - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ), - computation_types.FederatedType( - computation_types.AbstractType('T'), - placements.CLIENTS, - all_equal=True, - ), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('U'), - placements.CLIENTS, - all_equal=True, - ), - ), -) - -# Computes a simple (equally weighted) mean of client items. Only supported -# for numeric tensor types, or composite structures made up of numeric types. -# -# Type signature: {T}@CLIENTS -> T@SERVER -FEDERATED_MEAN = IntrinsicDef( - 'FEDERATED_MEAN', - 'federated_mean', - computation_types.FunctionType( - parameter=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.DEFAULT, -) - -# Computes the min of client values on the server. Only supported for numeric -# types, or nested structures made up of numeric computation_types. -# -# Type signature: {T}@CLIENTS -> T@SERVER -FEDERATED_MIN = IntrinsicDef( - 'FEDERATED_MIN', - 'federated_min', - computation_types.FunctionType( - parameter=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.DEFAULT, -) - -# Computes the max of client values on the server. Only supported for numeric -# types, or nested structures made up of numeric computation_types. -# -# Type signature: {T}@CLIENTS -> T@SERVER -FEDERATED_MAX = IntrinsicDef( - 'FEDERATED_MAX', - 'federated_max', - computation_types.FunctionType( - parameter=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.DEFAULT, -) - -# Computes the sum of client values on the server, securely. Only supported for -# integers or nested structures of integers. -# -# Type signature: <{V}@CLIENTS,M> -> V@SERVER -FEDERATED_SECURE_SUM = IntrinsicDef( - 'FEDERATED_SECURE_SUM', - 'federated_secure_sum', - computation_types.FunctionType( - parameter=[ - computation_types.FederatedType( - computation_types.AbstractType('V'), placements.CLIENTS - ), - computation_types.AbstractType('M'), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('V'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.SECURE, -) - -# Computes the sum of client values on the server, securely. Only supported for -# integers or nested structures of integers. -# -# Type signature: <{V}@CLIENTS,B> -> V@SERVER -FEDERATED_SECURE_SUM_BITWIDTH = IntrinsicDef( - 'FEDERATED_SECURE_SUM_BITWIDTH', - 'federated_secure_sum_bitwidth', - computation_types.FunctionType( - parameter=[ - computation_types.FederatedType( - computation_types.AbstractType('V'), placements.CLIENTS - ), - computation_types.AbstractType('B'), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('V'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.SECURE, -) - -_SELECT_TYPE = computation_types.FunctionType( - parameter=[ - computation_types.FederatedType( - computation_types.AbstractType('Ks'), placements.CLIENTS - ), # client_keys - computation_types.FederatedType(np.int32, placements.SERVER), # max_key - computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), # server_state - computation_types.FunctionType( - [computation_types.AbstractType('T'), np.int32], - computation_types.AbstractType('U'), - ), # select_fn - ], - result=computation_types.FederatedType( - computation_types.SequenceType(computation_types.AbstractType('U')), - placements.CLIENTS, - ), -) - -# Distributes server values to clients based on client keys. -FEDERATED_SELECT = IntrinsicDef( - 'FEDERATED_SELECT', - 'federated_select', - _SELECT_TYPE, - broadcast_kind=BroadcastKind.DEFAULT, -) - -# Securely distributes server values to clients based on private client keys. -FEDERATED_SECURE_SELECT = IntrinsicDef( - 'FEDERATED_SECURE_SELECT', - 'federated_secure_select', - _SELECT_TYPE, - broadcast_kind=BroadcastKind.SECURE, -) - -# Computes the sum of client values on the server. Only supported for numeric -# types, or nested structures made up of numeric computation_types. -# -# Type signature: {T}@CLIENTS -> T@SERVER -FEDERATED_SUM = IntrinsicDef( - 'FEDERATED_SUM', - 'federated_sum', - computation_types.FunctionType( - parameter=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.DEFAULT, -) - -# Places a value at the clients. -# -# Type signature: T -> T@CLIENTS -FEDERATED_VALUE_AT_CLIENTS = IntrinsicDef( - 'FEDERATED_VALUE_AT_CLIENTS', - 'federated_value_at_clients', - computation_types.FunctionType( - parameter=computation_types.AbstractType('T'), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS, True - ), - ), -) - -# Places a value at the server. -# -# Type signature: T -> T@SERVER -FEDERATED_VALUE_AT_SERVER = IntrinsicDef( - 'FEDERATED_VALUE_AT_SERVER', - 'federated_value_at_server', - computation_types.FunctionType( - parameter=computation_types.AbstractType('T'), - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), -) - -# Computes a weighted mean of client items. Only supported for numeric tensor -# types, or composite structures made up of numeric types. Weights must be -# simple scalar numeric (integer or floating point, not complex) tensor types. -# The types of weights and values must be compatible, i.e., multiplying and -# dividing member constituents of the value by weights should yield results of -# the same type as the type of these member constituents being weighted. Thus, -# for example, one may not supply values containing `np.int32`` tensors, as the -# result of weighting such values is of a floating-point type. Casting must be -# injected, where appropriate, by the compiler before invoking this intrinsic. -# -# Type signature: <{T}@CLIENTS,{U}@CLIENTS> -> T@SERVER -FEDERATED_WEIGHTED_MEAN = IntrinsicDef( - 'FEDERATED_WEIGHTED_MEAN', - 'federated_weighted_mean', - computation_types.FunctionType( - parameter=[ - computation_types.FederatedType( - computation_types.AbstractType('T'), placements.CLIENTS - ), - computation_types.FederatedType( - computation_types.AbstractType('U'), placements.CLIENTS - ), - ], - result=computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ), - ), - aggregation_kind=AggregationKind.DEFAULT, -) - -# Zips a tuple of two federated types into a federated tuple. -# -# Type signature: T -> U@CLIENTS -# where `T` is a structure of client-placed values. -# `U` must be identical to `T` with federated placement removed. -FEDERATED_ZIP_AT_CLIENTS = IntrinsicDef( - 'FEDERATED_ZIP_AT_CLIENTS', - 'federated_zip_at_clients', - computation_types.FunctionType( - parameter=computation_types.AbstractType('T'), - result=computation_types.FederatedType( - computation_types.AbstractType('U'), placements.CLIENTS - ), - ), -) -# Type signature: T -> U@SERVER -# where `T` is a structure of server-placed values. -# `U` must be identical to `T` with federated placement removed. -FEDERATED_ZIP_AT_SERVER = IntrinsicDef( - 'FEDERATED_ZIP_AT_SERVER', - 'federated_zip_at_server', - computation_types.FunctionType( - parameter=computation_types.AbstractType('T'), - result=computation_types.FederatedType( - computation_types.AbstractType('U'), placements.SERVER - ), - ), -) - -# TODO: b/122728050 - Define GENERIC_DIVIDE, GENERIC_MULTIPLY, and GENERIC_ONE -# to support intrinsic reductions (see the uses in intrinsic_bodies.py for -# the motivation and usage in support of which we need to define semantics). - -# Generic plus operator that accepts a variety of types in federated computation -# context. The range of types 'T' supported to be defined. It should work in a -# natural manner for tensors, tuples, federated types, possibly sequences, and -# maybe even functions (although it's unclear if such generality is desirable). -# -# TODO: b/113123410 - Define the range of supported computation_types. -# -# Type signature: -> T -GENERIC_PLUS = IntrinsicDef( - 'GENERIC_PLUS', - 'generic_plus', - type_factory.binary_op(computation_types.AbstractType('T')), -) - -# Performs pointwise TensorFlow division on its arguments. -# The type signature of generic divide is determined by TensorFlow's set of -# implicit type equations. For example, dividing `int32` by `int32` in TF -# generates a tensor of type `float64`. There is therefore more structure than -# is suggested by the type signature ` -> U`. -# Type signature: -> U -GENERIC_DIVIDE = IntrinsicDef( - 'GENERIC_DIVIDE', - 'generic_divide', - computation_types.FunctionType( - [ - computation_types.AbstractType('T'), - computation_types.AbstractType('T'), - ], - computation_types.AbstractType('U'), - ), -) - -# Performs pointwise TensorFlow multiplication on its arguments. -# Type signature: -> T -GENERIC_MULTIPLY = IntrinsicDef( - 'GENERIC_MULTIPLY', - 'generic_multiply', - computation_types.FunctionType( - [computation_types.AbstractType('T')] * 2, - computation_types.AbstractType('T'), - ), -) -# Generic zero operator that represents zero-filled values of diverse types (to -# be defined, but generally similar to that supported by GENERIC_ADD). -# -# TODO: b/113123410 - Define the range of supported computation_types. -# -# Type signature: T -GENERIC_ZERO = IntrinsicDef( - 'GENERIC_ZERO', 'generic_zero', computation_types.AbstractType('T') -) - -# Maps elements of a sequence using a given mapping function that operates -# independently on each element. -# -# Type signature: <(T->U),T*> -> U* -SEQUENCE_MAP = IntrinsicDef( - 'SEQUENCE_MAP', - 'sequence_map', - computation_types.FunctionType( - parameter=[ - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ), - computation_types.SequenceType(computation_types.AbstractType('T')), - ], - result=computation_types.SequenceType( - computation_types.AbstractType('U') - ), - ), -) - -# Reduces a sequence using a given 'zero' in the algebra (i.e., the result of -# reducing an empty sequence) and a given reduction operator with the signature -# U,T->U that incorporates a single T-typed element into a U-typed result of -# partial reduction. In the special case of T = U, this corresponds to the -# classical notion of reduction of a set using a commutative associative binary -# operator. The generalized reduction operator (with T != U) must yield the same -# results when repeatedly applied on sequences of elements in any order. -# -# Type signature: ->U)> -> U -SEQUENCE_REDUCE = IntrinsicDef( - 'SEQUENCE_REDUCE', - 'sequence_reduce', - computation_types.FunctionType( - parameter=[ - computation_types.SequenceType(computation_types.AbstractType('T')), - computation_types.AbstractType('U'), - type_factory.reduction_op( - computation_types.AbstractType('U'), - computation_types.AbstractType('T'), - ), - ], - result=computation_types.AbstractType('U'), - ), -) - -# Computes the sum of values in a sequence. Only supported for numeric types -# or nested structures made up of numeric types. -# -# Type signature: T* -> T -SEQUENCE_SUM = IntrinsicDef( - 'SEQUENCE_SUM', - 'sequence_sum', - computation_types.FunctionType( - parameter=computation_types.SequenceType( - computation_types.AbstractType('T') - ), - result=computation_types.AbstractType('T'), - ), -) - - -def uri_to_intrinsic_def(uri) -> Optional[IntrinsicDef]: - return _intrinsic_registry.get(uri) - - -# TODO: b/254770431 - Add documentation explaining the implications of setting -# broadcast_kind for an intrinsic. -def get_broadcast_intrinsics() -> list[IntrinsicDef]: - return [ - intrinsic - for intrinsic in _intrinsic_registry.values() - if intrinsic.broadcast_kind - ] - - -# TODO: b/254770431 - Add documentation explaining the implications of setting -# aggregation_kind for an intrinsic. -def get_aggregation_intrinsics() -> list[IntrinsicDef]: - return [ - intrinsic - for intrinsic in _intrinsic_registry.values() - if intrinsic.aggregation_kind - ] diff --git a/tensorflow_federated/python/core/impl/compiler/intrinsic_defs_test.py b/tensorflow_federated/python/core/impl/compiler/intrinsic_defs_test.py deleted file mode 100644 index fadadee701..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/intrinsic_defs_test.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized - -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs - - -def _get_intrinsic_named_parameters(): - def _predicate(obj): - return isinstance(obj, intrinsic_defs.IntrinsicDef) - - objects = [getattr(intrinsic_defs, x) for x in dir(intrinsic_defs)] - intrinsics = filter(_predicate, objects) - return [(x.name, x) for x in intrinsics] - - -class IntrinsicDefsTest(parameterized.TestCase): - - @parameterized.named_parameters(*_get_intrinsic_named_parameters()) - def test_names_match_those_in_module(self, intrinsic): - self.assertEqual(intrinsic, getattr(intrinsic_defs, intrinsic.name)) - - def test_uris_are_unique(self): - uris = set([x.uri for _, x in _get_intrinsic_named_parameters()]) - expected_length = len(_get_intrinsic_named_parameters()) - self.assertLen(uris, expected_length) - - @parameterized.named_parameters( - ('federated_broadcast', 'FEDERATED_BROADCAST', '(T@SERVER -> T@CLIENTS)'), - ( - 'federated_eval_at_clients', - 'FEDERATED_EVAL_AT_CLIENTS', - '(( -> T) -> {T}@CLIENTS)', - ), - ( - 'federated_eval_at_server', - 'FEDERATED_EVAL_AT_SERVER', - '(( -> T) -> T@SERVER)', - ), - ( - 'federated_map', - 'FEDERATED_MAP', - '(<(T -> U),{T}@CLIENTS> -> {U}@CLIENTS)', - ), - ( - 'federated_secure_sum', - 'FEDERATED_SECURE_SUM', - '(<{V}@CLIENTS,M> -> V@SERVER)', - ), - ( - 'federated_secure_sum_bitwidth', - 'FEDERATED_SECURE_SUM_BITWIDTH', - '(<{V}@CLIENTS,B> -> V@SERVER)', - ), - ( - 'federated_secure_select', - 'FEDERATED_SECURE_SELECT', - ( - '(<{Ks}@CLIENTS,int32@SERVER,T@SERVER,( -> U)> ->' - ' {U*}@CLIENTS)' - ), - ), - ( - 'federated_select', - 'FEDERATED_SELECT', - ( - '(<{Ks}@CLIENTS,int32@SERVER,T@SERVER,( -> U)> ->' - ' {U*}@CLIENTS)' - ), - ), - ('federated_sum', 'FEDERATED_SUM', '({T}@CLIENTS -> T@SERVER)'), - ( - 'federated_zip_at_clients', - 'FEDERATED_ZIP_AT_CLIENTS', - '(T -> {U}@CLIENTS)', - ), - ('federated_zip_at_server', 'FEDERATED_ZIP_AT_SERVER', '(T -> U@SERVER)'), - ) - def test_type_signature_strings(self, name, type_str): - intrinsic = getattr(intrinsic_defs, name) - self.assertEqual( - intrinsic.type_signature.compact_representation(), type_str - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/transformation_utils.py b/tensorflow_federated/python/core/impl/compiler/transformation_utils.py deleted file mode 100644 index 6d40642a38..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/transformation_utils.py +++ /dev/null @@ -1,1249 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of transformation utilities.""" - -import abc -import collections -from collections.abc import Callable -import itertools -import operator -import typing - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_blocks - - -def transform_postorder(comp, transform): - """Traverses `comp` recursively postorder and replaces its constituents. - - For each element of `comp` viewed as an expression tree, the transformation - `transform` is applied first to building blocks it is parameterized by, then - the element itself. The transformation `transform` should act as an identity - function on the kinds of elements (computation building blocks) it does not - care to transform. This corresponds to a post-order traversal of the - expression tree, i.e., parameters are always transformed left-to-right (in - the order in which they are listed in building block constructors), then the - parent is visited and transformed with the already-visited, and possibly - transformed arguments in place. - - Note: In particular, in `Call(f,x)`, both `f` and `x` are arguments to `Call`. - Therefore, `f` is transformed into `f'`, next `x` into `x'` and finally, - `Call(f',x')` is transformed at the end. - - Args: - comp: A `computation_building_block.ComputationBuildingBlock` to traverse - and transform bottom-up. - transform: The transformation to apply locally to each building block in - `comp`. It is a Python function that accepts a building block at input, - and should return a (building block, bool) tuple as output, where the - building block is a `computation_building_block.ComputationBuildingBlock` - representing either the original building block or a transformed building - block and the bool is a flag indicating if the building block was modified - as. - - Returns: - The result of applying `transform` to parts of `comp` in a bottom-up - fashion, along with a Boolean with the value `True` if `comp` was - transformed and `False` if it was not. - - Raises: - TypeError: If the arguments are of the wrong computation_types. - NotImplementedError: If the argument is a kind of computation building block - that is currently not recognized. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - if isinstance( - comp, - ( - building_blocks.CompiledComputation, - building_blocks.Data, - building_blocks.Intrinsic, - building_blocks.Literal, - building_blocks.Placement, - building_blocks.Reference, - ), - ): - return transform(comp) - elif isinstance(comp, building_blocks.Selection): - source, source_modified = transform_postorder(comp.source, transform) - if source_modified: - comp = building_blocks.Selection(source, comp.name, comp.index) - comp, comp_modified = transform(comp) - return comp, comp_modified or source_modified - elif isinstance(comp, building_blocks.Struct): - elements = [] - elements_modified = False - for key, value in structure.iter_elements(comp): - value, value_modified = transform_postorder(value, transform) - elements.append((key, value)) - elements_modified = elements_modified or value_modified - if elements_modified: - comp = building_blocks.Struct( - elements, container_type=comp.type_signature.python_container - ) - comp, comp_modified = transform(comp) - return comp, comp_modified or elements_modified - elif isinstance(comp, building_blocks.Call): - fn, fn_modified = transform_postorder(comp.function, transform) - if comp.argument is not None: - arg, arg_modified = transform_postorder(comp.argument, transform) - else: - arg, arg_modified = (None, False) - if fn_modified or arg_modified: - comp = building_blocks.Call(fn, arg) - comp, comp_modified = transform(comp) - return comp, comp_modified or fn_modified or arg_modified - elif isinstance(comp, building_blocks.Lambda): - result, result_modified = transform_postorder(comp.result, transform) - if result_modified: - comp = building_blocks.Lambda( - comp.parameter_name, comp.parameter_type, result - ) - comp, comp_modified = transform(comp) - return comp, comp_modified or result_modified - elif isinstance(comp, building_blocks.Block): - variables = [] - variables_modified = False - for key, value in comp.locals: - value, value_modified = transform_postorder(value, transform) - variables.append((key, value)) - variables_modified = variables_modified or value_modified - result, result_modified = transform_postorder(comp.result, transform) - if variables_modified or result_modified: - comp = building_blocks.Block(variables, result) - comp, comp_modified = transform(comp) - return comp, comp_modified or variables_modified or result_modified - else: - raise NotImplementedError( - 'Unrecognized computation building block: {}'.format(str(comp)) - ) - - -TransformReturnType = tuple[building_blocks.ComputationBuildingBlock, bool] - - -def transform_preorder( - comp: building_blocks.ComputationBuildingBlock, - transform: Callable[ - [building_blocks.ComputationBuildingBlock], TransformReturnType - ], -) -> TransformReturnType: - """Walks the AST of `comp` preorder, calling `transform` on the way down. - - Notice that this function will stop walking the tree when its transform - function modifies a node; this is to prevent the caller from unexpectedly - kicking off an infinite recursion. For this purpose the transform function - must identify when it has transformed the structure of a building block; if - the structure of the building block is modified but `False` is returned as - the second element of the tuple returned by `transform`, `transform_preorder` - may result in an infinite recursion. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` to be - transformed in a preorder fashion. - transform: Transform function to be applied to the nodes of `comp`. Must - return a two-tuple whose first element is a - `building_blocks.ComputationBuildingBlock` and whose second element is a - Boolean. If the computation which is passed to `comp` is returned in a - modified state, must return `True` for the second element. This Boolean - controls whether or not to stop traversing the tree under `comp`; if this - Bool is `True`, `transform_preorder` will not traverse this subtree. - - Returns: - A two-tuple, whose first element is modified version of `comp`, and - whose second element is a Boolean indicating whether `comp` was transformed - during the walk. - - Raises: - TypeError: If the argument types don't match those specified above. - """ - - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - inner_comp, modified = transform(comp) - if modified: - return inner_comp, modified - if isinstance( - inner_comp, - ( - building_blocks.CompiledComputation, - building_blocks.Data, - building_blocks.Intrinsic, - building_blocks.Literal, - building_blocks.Placement, - building_blocks.Reference, - ), - ): - return inner_comp, modified - elif isinstance(inner_comp, building_blocks.Lambda): - transformed_result, result_modified = transform_preorder( - inner_comp.result, transform - ) - if not (modified or result_modified): - return inner_comp, False - return ( - building_blocks.Lambda( - inner_comp.parameter_name, - inner_comp.parameter_type, - transformed_result, - ), - True, - ) - elif isinstance(inner_comp, building_blocks.Struct): - elements_modified = False - elements = [] - for name, val in structure.iter_elements(inner_comp): - result, result_modified = transform_preorder(val, transform) - elements_modified = elements_modified or result_modified - elements.append((name, result)) - if not (modified or elements_modified): - return inner_comp, False - return building_blocks.Struct(elements), True - elif isinstance(inner_comp, building_blocks.Selection): - transformed_source, source_modified = transform_preorder( - inner_comp.source, transform - ) - if not (modified or source_modified): - return inner_comp, False - return ( - building_blocks.Selection( - transformed_source, inner_comp.name, inner_comp.index - ), - True, - ) - elif isinstance(inner_comp, building_blocks.Call): - transformed_fn, fn_modified = transform_preorder( - inner_comp.function, transform - ) - if inner_comp.argument is not None: - transformed_arg, arg_modified = transform_preorder( - inner_comp.argument, transform - ) - else: - transformed_arg = None - arg_modified = False - if not (modified or fn_modified or arg_modified): - return inner_comp, False - return building_blocks.Call(transformed_fn, transformed_arg), True - elif isinstance(inner_comp, building_blocks.Block): - transformed_variables = [] - values_modified = False - for key, value in inner_comp.locals: - transformed_value, value_modified = transform_preorder(value, transform) - transformed_variables.append((key, transformed_value)) - values_modified = values_modified or value_modified - transformed_result, result_modified = transform_preorder( - inner_comp.result, transform - ) - if not (modified or values_modified or result_modified): - return inner_comp, False - return ( - building_blocks.Block(transformed_variables, transformed_result), - True, - ) - else: - raise NotImplementedError( - 'Unrecognized computation building block: {}'.format(str(inner_comp)) - ) - - -def transform_postorder_with_symbol_bindings(comp, transform, symbol_tree): - """Uses symbol binding hooks to execute transformations. - - `transform_postorder_with_symbol_bindings` hooks into the preorder traversal - that is defined by walking down the tree to its leaves, using - the variable bindings along this path to push information onto - the given `SymbolTree`. Once we hit the leaves, we walk back up the - tree in a postorder fashion, calling `transform` as we go. - - The transformations `transform_postorder_with_symbol_bindings` executes are - therefore stateful in some sense. Here 'stateful' means that a transformation - executed on a given AST node in general depends on not only the node itself - or its immediate vicinity; possibly there is some global information on which - this transformation depends. `transform_postorder_with_symbol_bindings` is - functional 'from AST to AST' (where `comp` represents the root of an AST) but - not 'from node to node'. - - One important fact to note: there are recursion invariants that - `transform_postorder_with_symbol_bindings` uses the `SymbolTree` data - structure to enforce. In particular, within a `transform` call the following - invariants hold: - - * `symbol_tree.update_payload_with_name` with an argument `name` will call - `update` on the `BoundVariableTracker` in `symbol_tree` which tracks the - value of `ref` active in the current lexical scope. Will raise a - `NameError` if none exists. - - * `symbol_tree.get_payload_with_name` with a string argument `name` will - return the `BoundVariableTracker` instance from `symbol_tree` which - corresponds to the computation bound to the variable `name` in the current - lexical scope. Will raise a `NameError` if none exists. - - These recursion invariants are enforced by the framework, and should be - relied on when designing new transformations that depend on variable - bindings. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` to read - information from or transform. - transform: Python function accepting `comp` and `symbol_tree` arguments and - returning `transformed_comp`. - symbol_tree: Instance of `SymbolTree`, the data structure into which we may - read information about variable bindings, and from which we may read. - - Returns: - Returns a possibly modified version of `comp`, an instance - of `building_blocks.ComputationBuildingBlock`, along with a - Boolean with the value `True` if `comp` was transformed and `False` if it - was not. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(symbol_tree, SymbolTree) - if not callable(transform): - raise TypeError( - 'Argument `transform` to ' - '`transform_postorder_with_symbol_bindings` must ' - 'be callable.' - ) - identifier_seq = itertools.count(start=1) - - def _transform_postorder_with_symbol_bindings_switch( - comp, transform_fn, ctxt_tree, identifier_sequence - ): - """Recursive helper function delegated to after binding comp_id sequence.""" - if isinstance( - comp, - ( - building_blocks.CompiledComputation, - building_blocks.Data, - building_blocks.Intrinsic, - building_blocks.Literal, - building_blocks.Placement, - building_blocks.Reference, - ), - ): - return _traverse_leaf(comp, transform_fn, ctxt_tree, identifier_sequence) - elif isinstance(comp, building_blocks.Selection): - return _traverse_selection( - comp, transform, ctxt_tree, identifier_sequence - ) - elif isinstance(comp, building_blocks.Struct): - return _traverse_tuple(comp, transform, ctxt_tree, identifier_sequence) - elif isinstance(comp, building_blocks.Call): - return _traverse_call(comp, transform, ctxt_tree, identifier_sequence) - elif isinstance(comp, building_blocks.Lambda): - return _traverse_lambda(comp, transform, ctxt_tree, identifier_sequence) - elif isinstance(comp, building_blocks.Block): - return _traverse_block(comp, transform, ctxt_tree, identifier_sequence) - else: - raise NotImplementedError( - 'Unrecognized computation building block: {}'.format(str(comp)) - ) - - def _traverse_leaf(comp, transform, context_tree, identifier_seq): - """Helper function holding traversal logic for leaf nodes.""" - _ = next(identifier_seq) - return transform(comp, context_tree) - - def _traverse_selection(comp, transform, context_tree, identifier_seq): - """Helper function holding traversal logic for selection nodes.""" - _ = next(identifier_seq) - source, source_modified = _transform_postorder_with_symbol_bindings_switch( - comp.source, transform, context_tree, identifier_seq - ) - if source_modified: - # Normalize selection to index based on the type signature of the - # original source. The new source may not have names present. - if comp.index is not None: - index = comp.index - else: - index = structure.name_to_index_map(comp.source.type_signature)[ - comp.name - ] - comp = building_blocks.Selection(source, index=index) - comp, comp_modified = transform(comp, context_tree) - return comp, comp_modified or source_modified - - def _traverse_tuple(comp, transform, context_tree, identifier_seq): - """Helper function holding traversal logic for tuple nodes.""" - _ = next(identifier_seq) - elements = [] - elements_modified = False - for key, value in structure.iter_elements(comp): - value, value_modified = _transform_postorder_with_symbol_bindings_switch( - value, transform, context_tree, identifier_seq - ) - elements.append((key, value)) - elements_modified = elements_modified or value_modified - if elements_modified: - comp = building_blocks.Struct(elements) - comp, comp_modified = transform(comp, context_tree) - return comp, comp_modified or elements_modified - - def _traverse_call(comp, transform, context_tree, identifier_seq): - """Helper function holding traversal logic for call nodes.""" - _ = next(identifier_seq) - fn, fn_modified = _transform_postorder_with_symbol_bindings_switch( - comp.function, transform, context_tree, identifier_seq - ) - if comp.argument is not None: - arg, arg_modified = _transform_postorder_with_symbol_bindings_switch( - comp.argument, transform, context_tree, identifier_seq - ) - else: - arg, arg_modified = (None, False) - if fn_modified or arg_modified: - comp = building_blocks.Call(fn, arg) - comp, comp_modified = transform(comp, context_tree) - return comp, comp_modified or fn_modified or arg_modified - - def _traverse_lambda(comp, transform, context_tree, identifier_seq): - """Helper function holding traversal logic for lambda nodes.""" - comp_id = next(identifier_seq) - context_tree.drop_scope_down(comp_id) - context_tree.ingest_variable_binding(name=comp.parameter_name, value=None) - result, result_modified = _transform_postorder_with_symbol_bindings_switch( - comp.result, transform, context_tree, identifier_seq - ) - context_tree.walk_to_scope_beginning() - if result_modified: - comp = building_blocks.Lambda( - comp.parameter_name, comp.parameter_type, result - ) - comp, comp_modified = transform(comp, context_tree) - context_tree.pop_scope_up() - return comp, comp_modified or result_modified - - def _traverse_block(comp, transform, context_tree, identifier_seq): - """Helper function holding traversal logic for block nodes.""" - comp_id = next(identifier_seq) - context_tree.drop_scope_down(comp_id) - variables = [] - variables_modified = False - for key, value in comp.locals: - value, value_modified = _transform_postorder_with_symbol_bindings_switch( - value, transform, context_tree, identifier_seq - ) - context_tree.ingest_variable_binding(name=key, value=value) - variables.append((key, value)) - variables_modified = variables_modified or value_modified - result, result_modified = _transform_postorder_with_symbol_bindings_switch( - comp.result, transform, context_tree, identifier_seq - ) - context_tree.walk_to_scope_beginning() - if variables_modified or result_modified: - comp = building_blocks.Block(variables, result) - comp, comp_modified = transform(comp, context_tree) - context_tree.pop_scope_up() - return comp, comp_modified or variables_modified or result_modified - - return _transform_postorder_with_symbol_bindings_switch( - comp, transform, symbol_tree, identifier_seq - ) - - -class BoundVariableTracker(metaclass=abc.ABCMeta): - """Abstract class representing a mutable variable binding.""" - - def __init__(self, name, value): - """Initializes `BoundVariableTracker`. - - The initializer is likely to be overwritten by subclasses in order to - attach more state to the `BoundVariableTracker`. Each of them must - satisfy the same interface, however. This is simply because the - `BoundVariableTracker` represents a variable binding in a TFF AST; - no more information is avaiable to it than the `name`-`value` pair - being bound together. - - Args: - name: String name of variable to be bound. - value: Value to bind to this name. Can be instance of - `building_blocks.ComputationBuildingBlock` if this - `BoundVariableTracker` represents a concrete binding to a variable (e.g. - in a block locals declaration), or `None`, if this - `BoundVariableTracker` represents merely a variable declaration (e.g. in - a lambda). - """ - py_typecheck.check_type(name, str) - if value is not None: - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - self.name = name - self.value = value - - def update(self, value=None): - """Defines the way information is read into this node. - - Defaults to no-op. - - Args: - value: Similar to `value` argument in initializer. - """ - del value # Unused - - @abc.abstractmethod - def __str__(self): - """Abstract string method required as context tree will delegate.""" - - def __eq__(self, other): - """Base class equality checks names and values equal.""" - # TODO: b/130890785 - Delegate value-checking to - # `building_blocks.ComputationBuildingBlock`. - if self is other: - return True - if not isinstance(other, BoundVariableTracker): - return NotImplemented - if self.name != other.name: - return False - if isinstance( - self.value, building_blocks.ComputationBuildingBlock - ) and isinstance(other.value, building_blocks.ComputationBuildingBlock): - return ( - self.value.compact_representation() - == other.value.compact_representation() - and self.value.type_signature.is_equivalent_to( - other.value.type_signature - ) - ) - return self.value is other.value - - def __ne__(self, other): - """Implementing __ne__ to enforce in Python2 the Python3 standard.""" - return not self == other - - -class _BeginScopePointer(BoundVariableTracker): - """Sentinel representing the beginning of a scope defined by an AST node.""" - - def __init__(self, name=None, value=None): - if name is not None or value is not None: - raise ValueError( - "Please don't pass a name or value to " - '_BeginScopePointer; it will simply be ignored.' - ) - super().__init__('BeginScope', None) - - def update(self, value=None): - del value # Unused. - raise RuntimeError("We shouldn't be trying to update the outer context.") - - def __str__(self): - return self.name - - def __eq__(self, other): - """Returns `True` iff `other` is also a `_BeginScopePointer`. - - Args: - other: Value for equality comparison. - - Returns: - Returns true iff `other` is also an instance of `_BeginScopePointer`. - """ - # Using explicit type comparisons here to prevent a subclass from passing. - # pylint: disable=unidiomatic-typecheck - return type(other) is _BeginScopePointer - # pylint: enable=unidiomatic-typecheck - - -class SymbolTree: - """Data structure to hold variable bindings as we walk an AST. - - `SymbolTree` is designed to be constructed and mutatated as we traverse an - AST, maintaining a pointer to an active node representing the variable - bindings we currently have available as we walk the AST. - - `SymbolTree` is a hierarchical tree-like data structure. Its levels - correspond to nodes in the TFF AST it is tracking, meaning that walking into - or out of a scope-defining TFF node (a block or lambda) corresponds to - moving up or down a level in the `SymbolTree`. Block constructs (a.k.a. - the let statement) binds variables sequentially, and this sequential binding - corresponds to variables bound at the same level of the `SymbolTree`. - - Each instance of the node class can be used at most once in the symbol tree, - as checked by memory location. This disallows circular tree structures that - could cause an infinite loop in recursive equality testing or printing. - """ - - def __init__(self, payload_type: type[BoundVariableTracker]): - """Initializes `SymbolTree` with its payload type. - - Args: - payload_type: Class which subclasses BoundVariableTracker; the type of - payloads to be constructed and held in this SymbolTree. - """ - initial_node = SequentialBindingNode(_BeginScopePointer()) - self.active_node = initial_node - self.payload_type = payload_type - self._node_ids = {id(initial_node): 1} - - def get_payload_with_name(self, name): - """Returns payload corresponding to `name` in active variable bindings. - - Note that this method obeys `dict.get`-like semantics; instead of raising - when asked to address an unbound name, it simply returns `None`. - - Args: - name: String name to find in currently active context. - - Returns: - Returns instance of `BoundVariableTracker` corresponding to `name` - in context represented by `active_comp`, or `None` if the requested - name is unbound in the current context. - """ - py_typecheck.check_type(name, str) - comp = typing.cast(SequentialBindingNode, self.active_node) - while comp.parent is not None or comp.older_sibling is not None: - if name == comp.payload.name: - return comp.payload - if comp.older_sibling is not None: - comp = comp.older_sibling - elif comp.parent is not None: - comp = comp.parent - return None - - def get_higher_payloads_with_value(self, value, equal_fn=None): - """Returns payloads above `active_node` whose `value` is equal to `value`. - - Args: - value: The value to test. - equal_fn: The optional function to use to determine equality, if `None` is - specified `operator.is_` is used. - """ - payloads = [] - if equal_fn is None: - equal_fn = operator.is_ - node = typing.cast(SequentialBindingNode, self.active_node) - while node.parent is not None or node.older_sibling is not None: - if node.payload.value is not None and equal_fn(value, node.payload.value): - payloads.append(node.payload) - if node.older_sibling is not None: - node = node.older_sibling - elif node.parent is not None: - node = node.parent - return payloads - - def update_payload_with_name(self, name): - """Calls `update` if `name` is found among the available symbols. - - If there is no such available symbol, simply does nothing. - - Args: - name: A string; generally, this is the variable a walker has encountered - in a TFF AST, and which it is relying on `SymbolTable` to address - correctly. - - Raises: - ValueError: If `name` is not found among the bound names currently - available in `self`. - """ - py_typecheck.check_type(name, str) - comp = typing.cast(SequentialBindingNode, self.active_node) - while comp.parent is not None or comp.older_sibling is not None: - if name == comp.payload.name: - comp.payload.update(name) - return - if comp.older_sibling is not None: - comp = comp.older_sibling - elif comp.parent is not None: - comp = comp.parent - raise ValueError( - "The name '{}' is not available in '{}'.".format(name, self) - ) - - def walk_to_scope_beginning(self): - """Walks `active_node` back to the sentinel node beginning current scope. - - `walk_to_scope_beginning` resolves the issue of scope at a node which - introduces scope in the following manner: each of these nodes (for instance, - a `building_blocks.Lambda`) corresponds to a sentinel value of - the `_BeginScopePointer` class, ensuring that these nodes do not have access - to - scope that is technically not available to them. That is, we conceptualize - the node corresponding to `(x -> x)` as existing in the scope outside of the - binding of `x`, and therefore is unable to reference `x`. However, these - nodes can walk down their variable declarations via - `walk_down_one_variable_binding` in order to inspect these declarations and - perhaps execute some logic based on them. - """ - scope_sentinel = _BeginScopePointer() - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - assert self.active_node is not None - while self.active_node.payload != scope_sentinel: - self.active_node = self.active_node.older_sibling - assert self.active_node is not None - - def pop_scope_up(self): - """Moves `active_node` up one level in the `SymbolTree`. - - Raises: - Raises ValueError if we are already at the highest level. - """ - self.walk_to_scope_beginning() - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - if self.active_node.parent: - self.active_node = self.active_node.parent - else: - raise ValueError( - 'You have tried to pop out of the highest level in this `SymbolTree`.' - ) - - def drop_scope_down(self, comp_id): - """Constructs a new scope level for `self`. - - Scope levels in `SymbolTree` correspond to scope-introducing nodes in TFF - ASTs; that is, either `building_blocks.Block` or - `building_blocks.Lambda` nodes. Inside of these levels, - variables are bound in sequence. The implementer of a transformation - function needing to interact with scope should never need to explicitly walk - the scope levels `drop_scope_down` constructs; `drop_scope_down` is simply - provided - for ease of exposing to a traversal function. - - Args: - comp_id: Integer representing a unique key for the - `building_blocks.ComputationBuildingBlock` which is defines this scope. - Used to differentiate between scopes which both branch from the same - point in the tree. - """ - py_typecheck.check_type(comp_id, int) - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - if self.active_node.children.get(comp_id) is None: - node = SequentialBindingNode(_BeginScopePointer()) - self._add_child(comp_id, node) - self._move_to_child(comp_id) - else: - self._move_to_child(comp_id) - - def walk_down_one_variable_binding(self): - """Moves `active_node` to the younger sibling of the current active node. - - This action represents walking from one variable binding in the - `SymbolTree` to the next, sequentially. - - If there is no such variable binding, then the lower bound variables must - be accessed via `drop_scope_down`. - - Raises: - Raises ValueError if there is no such available variable binding. - """ - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - if self.active_node.younger_sibling: - self.active_node = self.active_node.younger_sibling - else: - raise ValueError( - 'You have tried to move to a nonexistent variable binding in {}' - .format(self) - ) - - def ingest_variable_binding(self, name, value): - """Constructs or updates node in symbol tree as AST is walked. - - Passes `name` and `value` onto the symbol tree's node constructor, with - `mode` determining how the node being constructed or updated - relates to the symbol tree's `active_node`. - - If there is no preexisting node in the symbol tree bearing the - requested relationship to the active node, a new one will be constructed and - initialized. If there is an existing node, `ingest_variable_binding` checks - that this node has the correct `payload.name`, and overwrites its - `payload.value` with the `value` argument. - - Args: - name: The string name of the `CompTracker` instance we are constructing or - updating. - value: Instance of `building_blocks.ComputationBuildingBlock` or `None`, - as in the `value` to pass to symbol tree's node payload constructor. - - Raises: - ValueError: If we are passed a name-mode pair such that a - preexisting node in the symbol tree bears this relationship with - the active node, but has a different name. This is an indication - that either a transformation has failed to happen in the symbol tree - or that we have a symbol tree instance that does not match the - computation we are currently processing. - """ - if (name is None or not name) and value is None: - return - py_typecheck.check_type(name, str) - if value is not None: - py_typecheck.check_type(value, building_blocks.ComputationBuildingBlock) - node = SequentialBindingNode(self.payload_type(name=name, value=value)) - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - if self.active_node.younger_sibling is None: - self._add_younger_sibling(node) - self.walk_down_one_variable_binding() - else: - if self.active_node.younger_sibling.payload.name != name: - raise ValueError( - 'You have a mismatch between your symbol tree and the ' - 'computation you are trying to process; your symbol tree is {} ' - 'and you are looking for a BoundVariableTracker with name {} ' - 'and value {}'.format(self, name, value) - ) - self.walk_down_one_variable_binding() - self.active_node.payload.value = value - - def _add_younger_sibling(self, comp_tracker): - """Appends comp as younger sibling of current `active_node`.""" - py_typecheck.check_type(comp_tracker, SequentialBindingNode) - if self._node_ids.get(id(comp_tracker)): - raise ValueError( - 'Each instance of {} can only appear once in a given symbol tree.' - .format(self.payload_type) - ) - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - if self.active_node.younger_sibling is not None: - raise ValueError('Ambiguity in adding a younger sibling') - comp_tracker.set_older_sibling(self.active_node) - self.active_node.set_younger_sibling(comp_tracker) - self._node_ids[id(comp_tracker)] = 1 - - def _add_child(self, constructing_comp_id, comp_tracker): - """Writes `comp_tracker` to children of active node. - - Each `SequentialBindingNode` keeps a `dict` of its children; `_add_child` - updates the value of this `dict` with key `constructing_comp_id` to be - `comp_tracker`. - - Notice that `constructing_comp_id` is simply a way of addressing the - children in this dict; it is not necessarily globally unique, as long - as it is sufficient to address child scopes. - - Args: - constructing_comp_id: Key to identify child being constructed from the - parent scope. - comp_tracker: Instance of `SequentialBindingNode`, the node to add as a - child of `active_node`. - """ - py_typecheck.check_type(comp_tracker, SequentialBindingNode) - if self._node_ids.get(id(comp_tracker)): - raise ValueError( - 'Each node can only appear once in a given' - 'symbol tree. You have tried to add {} ' - 'twice.'.format(comp_tracker.payload) - ) - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - comp_tracker.set_parent(self.active_node) - self.active_node.add_child(constructing_comp_id, comp_tracker) - self._node_ids[id(comp_tracker)] = 1 - - def _move_to_child(self, comp_id): - """Moves `active_node` to child of current active node with key `comp_id`. - - Args: - comp_id: Integer representing the position of the child we wish to update - `active_node` to point to in a preorder traversal of the AST. - - Raises: - ValueError: If the active node has no child with the correct id. - """ - self.active_node = typing.cast(SequentialBindingNode, self.active_node) - if self.active_node.children.get(comp_id) is not None: - self.active_node = self.active_node.get_child(comp_id) - else: - raise ValueError('You have tried to move to a nonexistent child.') - - def _equal_under_node(self, self_node, other_node): - """Recursive helper function to check equality of `SymbolTree`s.""" - if self_node is None and other_node is None: - return True - if self_node is None or other_node is None: - return False - if self_node.payload != other_node.payload: - return False - if len(self_node.children) != len(other_node.children): - return False - for (_, val_1), (_, val_2) in zip( - self_node.children.items(), other_node.children.items() - ): - # keys not compared to avoid coupling walking logic to `SymbolTree`. - if not self._equal_under_node(val_1, val_2): - return False - return self._equal_under_node( - self_node.younger_sibling, other_node.younger_sibling - ) - - def __eq__(self, other): - """Walks to root of `self` and `other` before testing equality of subtrees. - - Args: - other: Instance of `SymbolTree` to test for equality with `self`. - - Returns: - Returns `True` if and only if `self` and `other` are the same - structurally (each node has the same number of children and siblings) and - each node of `self` compares as equal with the node in the corresponding - position of `other`. - """ - if self is other: - return True - if not isinstance(other, SymbolTree): - return NotImplemented - self_root = _walk_to_root(self.active_node) - other_root = _walk_to_root(other.active_node) - return self._equal_under_node(self_root, other_root) - - def __ne__(self, other): - return not self == other - - def _string_under_node(self, node) -> str: - """Rescursive helper function to generate string reps of `SymbolTree`s.""" - py_typecheck.check_type(node, SequentialBindingNode) - if node is self.active_node: - active_node_indicator = '*' - else: - active_node_indicator = '' - symbol_tree_string = '[' + str(node.payload) + active_node_indicator + ']' - if node.children: - symbol_tree_string += '->{' - for _, child_node in node.children.items(): - if not child_node.older_sibling: - symbol_tree_string += '(' - symbol_tree_string += self._string_under_node(child_node) - symbol_tree_string += '),(' - symbol_tree_string = symbol_tree_string[:-2] - symbol_tree_string += '}' - if node.younger_sibling: - symbol_tree_string += '-' + self._string_under_node(node.younger_sibling) - return symbol_tree_string - - def __str__(self): - """Generates a string representation of this `SymbolTree`. - - First we walk up to the root node, then we walk down - the tree generating string rep of this symbol tree. - - Returns: - Returns a string representation of the current `SymbolTree`, with - the node labeled the active node identified with a *. - """ - node = self.active_node - root_node = _walk_to_root(node) - return self._string_under_node(root_node) - - -def _walk_to_root(node): - while node.parent is not None or node.older_sibling is not None: - while node.older_sibling is not None: - node = node.older_sibling - while node.parent is not None: - node = node.parent - return node - - -class SequentialBindingNode: - """Represents a node in a context tree with sequential-binding semantics. - - `SequentialBindingNode`s are designed to be constructed and pushed into - a context tree as an AST representing a given computation is walked. - - Each `SequentialBindingNode` holds as payload a variable binding in the AST. - The node-node relationships encoded by the `SequentialBindingNode` data - structure determine how the context tree must be walked in order to resolve - variables and track their values in the AST. - - Parent-child relationships represent relationships between levels of the AST, - meaning, moving through an AST node which defines a variable scope in preorder - corresponds to moving from a `SequentialBindingNode` to one of its children, - and moving through such a node postorder corresponds to moving from a - `SequentialBindingNode` to its parent. - - Sibling-sibling relationships are particular to sequential binding of - variables in `building_blocks.Block` constructs; binding - a new variable in such a construct corresponds to moving from a - `SequentialBindingNode` to its (unique) younger sibling. - """ - - def __init__(self, payload): - """Initializes `SequentialBindingNode`. - - Args: - payload: Instance of BoundVariableTracker representing the payload of this - node. - """ - py_typecheck.check_type(payload, BoundVariableTracker) - self.payload = payload - self._children = collections.OrderedDict() - self._parent = None - self._older_sibling = None - self._younger_sibling = None - - @property - def parent(self): - return self._parent - - @property - def children(self): - return self._children - - @property - def older_sibling(self): - return self._older_sibling - - @property - def younger_sibling(self): - return self._younger_sibling - - def set_parent(self, node): - """Sets the _parent scope of `self` to the binding embodied by `node`. - - This method should not be assumed to be efficient. - - Args: - node: Instance of `SequentialBindingNode` to set as parent of `self`. - """ - py_typecheck.check_type(node, SequentialBindingNode) - self._parent = node - - def set_older_sibling(self, node): - """Sets the older sibling scope of `self` to `node`. - - This method should not be assumed to be efficient. - - Args: - node: Instance of `SequentialBindingNode` to set as older sibling of - `self`. - """ - py_typecheck.check_type(node, SequentialBindingNode) - self._older_sibling = node - - def set_younger_sibling(self, node): - """Sets the younger sibling scope of `self` to `node`. - - This corresponds to binding a new variable in a - `building_blocks.Block` construct. - - This method should not be assumed to be efficient. - - Args: - node: Instance of `SequentialBindingNode` representing this new binding. - """ - py_typecheck.check_type(node, SequentialBindingNode) - self._younger_sibling = node - - def add_child(self, comp_id, node): - """Sets the child scope of `self` indexed by `comp_id` to `node`. - - This corresponds to encountering a node in a TFF AST which defines a - variable scope. - - If a child with this `comp_id` already exists, it is replaced, as in a - `dict`. - - Args: - comp_id: The identifier of the computation generating this scope. - node: Instance of `SequentialBindingNode` representing this new binding. - """ - py_typecheck.check_type(node, SequentialBindingNode) - self._children[comp_id] = node - - def get_child(self, comp_id): - """Returns the child of `self` identified by `comp_id` if one exists. - - Args: - comp_id: Integer used to address child of `self` by position of - corresponding AST node in a preorder traversal of the AST. - - Returns: - Instance of `SequentialBindingNode` if an appropriate child of `self` - exists, or `None`. - """ - return self._children.get(comp_id) - - -def list_comp_names(comp): - """Canonical list of string representations of nodes in `comp`. - - Used as a helper function to generate static name-to-index mappings. - - Args: - comp: The root of the AST for which we wish to generate a list of string - representations of all nodes. - - Returns: - names: Python `list` of string representations of nodes under `comp`. - This list is generated by walking the AST of `comp` in postorder fashion - and thus is deterministic. - """ - names = [] - - def _string_rep(inner_comp): - names.append(str(inner_comp)) - return inner_comp - - transform_postorder(comp, _string_rep) - return names - - -class ReferenceCounter(BoundVariableTracker): - """Data container to track number References to a variable in an AST. - - - Attributes: - name: The string name representing the variable whose binding is represented - by an instance of `ReferenceCounter`. - value: The value bound to `name`. Can be an instance of - `building_blocks.ComputationBuildingBlock` or None if this binding is - simply a placeholder, e.g. in a Lambda. - count: An integer tracking how many times the variable an instance of - `ReferenceCounter` represents is referenced in a TFF AST. - """ - - def __init__(self, name, value): - super().__init__(name, value) - self.count = 0 - - def update(self, value=None): - del value # Unused. - self.count += 1 - - def __str__(self): - return 'Instance count: {}; value: {}; name: {}.'.format( - self.count, self.value, self.name - ) - - def __repr__(self): - return str(self) - - def __eq__(self, other): - if self is other: - return True - if not isinstance(other, ReferenceCounter): - return NotImplemented - if not super().__eq__(other): - return False - return self.count == other.count - - -def get_count_of_references_to_variables(comp): - """Returns `SymbolTree` counting references to each bound variable in `comp`. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` representing - the root of the AST for which we want to read total reference counts by - context. - - Returns: - An instance of `SymbolTree` representing root of context tree populated with - `ReferenceCounter`s which contain the number of times each variable bound by - a `building_blocks.Lambda` or `building_blocks.Block` are referenced in - their computation's body. - """ - - reference_counter = SymbolTree(ReferenceCounter) - - def _should_transform(comp, context_tree): - del context_tree # Unused - return isinstance(comp, building_blocks.Reference) - - def transform_fn(comp, context_tree): - if _should_transform(comp, context_tree): - context_tree.update_payload_with_name(comp.name) - return comp, False - - transform_postorder_with_symbol_bindings( - comp, transform_fn, reference_counter - ) - return reference_counter - - -def get_unique_names(comp): - """Returns the unique names bound or referred to in `comp`.""" - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - names = set() - - def _update(comp): - if isinstance(comp, building_blocks.Block): - names.update([name for name, _ in comp.locals]) - elif isinstance(comp, building_blocks.Lambda): - if comp.parameter_type is not None: - names.add(comp.parameter_name) - elif isinstance(comp, building_blocks.Reference): - names.add(comp.name) - return comp, False - - transform_postorder(comp, _update) - return names - - -def get_map_of_unbound_references( - comp: building_blocks.ComputationBuildingBlock, -) -> dict[building_blocks.ComputationBuildingBlock, set[str]]: - """Gets a Python `dict` of unbound references in `comp`, keyed by Python `id`. - - Computations that are equal will have the same collections of unbounded - references, so it is safe to use `comp` as the key for this `dict` even though - a given computation may appear in many positions in the AST. - - Args: - comp: The computation building block to parse. - - Returns: - A Python `dict` of elements where keys are the computations in `comp` and - values are a Python `set` of the names of the unbound references in the - subtree of that computation. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - references = {} - - def _update(comp): - """Updates the Python dict of references.""" - if isinstance(comp, building_blocks.Reference): - references[comp] = set((comp.name,)) - elif isinstance(comp, building_blocks.Block): - references[comp] = set() - names = [] - for name, variable in comp.locals: - elements = references[variable] - references[comp].update([e for e in elements if e not in names]) - names.append(name) - elements = references[comp.result] - references[comp].update([e for e in elements if e not in names]) - elif isinstance(comp, building_blocks.Call): - elements = references[comp.function].copy() - if comp.argument is not None: - elements.update(references[comp.argument]) - references[comp] = elements - elif isinstance(comp, building_blocks.Lambda): - elements = references[comp.result] - references[comp] = set([e for e in elements if e != comp.parameter_name]) - elif isinstance(comp, building_blocks.Selection): - references[comp] = references[comp.source] - elif isinstance(comp, building_blocks.Struct): - elements = [references[e] for e in comp] - references[comp] = set(itertools.chain.from_iterable(elements)) - else: - references[comp] = set() - return comp, False - - transform_postorder(comp, _update) - return references - - -class TransformSpec(metaclass=abc.ABCMeta): - """Base class to express the should_transform/transform interface.""" - - def __init__(self, global_transform=False): - self._global_transform = global_transform - - @property - def global_transform(self): - return self._global_transform - - @abc.abstractmethod - def should_transform(self, comp): - pass - - @abc.abstractmethod - def transform(self, comp): - pass diff --git a/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py b/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py deleted file mode 100644 index 8b77582d58..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py +++ /dev/null @@ -1,2008 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.proto.v0 import computation_pb2 -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_block_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -def _construct_complex_symbol_tree(): - """Constructs complex context tree for mutation testing.""" - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - for _ in range(2): - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - symbol_tree.drop_scope_down(0) - for _ in range(2): - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - return symbol_tree - - -def _construct_simple_block(type_signature): - """Constructs minimal example of LET construct in TFF.""" - test_arg = building_blocks.Literal(1, type_signature) - result = building_blocks.Reference('x', test_arg.type_signature) - simple_block = building_blocks.Block([('x', test_arg)], result) - return simple_block - - -class UpdatableTracker(transformation_utils.BoundVariableTracker): - - def __init__(self, name, value): - super().__init__(name, value) - self.count = 0 - - def update(self, comp): - self.count += 1 - - def __str__(self): - return '{Count: ' + str(self.count) + '}' - - def __eq__(self, other): - return id(self) == id(other) - - -class FakeTracker(transformation_utils.BoundVariableTracker): - - def update(self, comp=None): - pass - - def __str__(self): - return self.name - - def __eq__(self, other): - return isinstance(other, FakeTracker) - - -def fake_tracker_node_factory(): - return transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ) - - -class TrivialBoundVariableTracker(transformation_utils.BoundVariableTracker): - - def update(self, comp): - pass - - def __str__(self): - return 'TrivialBoundVariableTracker' - - -def _construct_trivial_instance_of_all_computation_building_blocks(): - cbb_list = [] - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array(1, np.int32) - ) - ref_to_x = building_blocks.Reference('x', np.int32) - cbb_list.append(('reference', ref_to_x)) - lam = building_blocks.Lambda('x', np.int32, ref_to_x) - cbb_list.append(('lambda', lam)) - block = building_blocks.Block([('x', ref_to_x)], lam) - cbb_list.append(('block', block)) - data = building_blocks.Data(any_proto, np.int32) - cbb_list.append(('data', data)) - function_type = computation_types.FunctionType(np.int32, np.int32) - intrinsic = building_blocks.Intrinsic('whimsy_intrinsic', function_type) - cbb_list.append(('intrinsic', intrinsic)) - tff_struct = building_blocks.Struct([ref_to_x]) - cbb_list.append(('struct', tff_struct)) - selection = building_blocks.Selection(tff_struct, index=0) - cbb_list.append(('selection', selection)) - call = building_blocks.Call(lam, ref_to_x) - cbb_list.append(('call', call)) - tensor_type = computation_types.TensorType(np.int32) - proto = computation_factory.create_lambda_identity(tensor_type) - function_type = computation_types.FunctionType(tensor_type, tensor_type) - compiled_comp = building_blocks.CompiledComputation( - proto, type_signature=function_type - ) - cbb_list.append(('compiled_comp', compiled_comp)) - placement = building_blocks.Placement(placements.CLIENTS) - cbb_list.append(('placement', placement)) - return cbb_list - - -def _get_number_of_nodes_via_transform_postorder(comp, predicate=None): - """Returns the number of nodes in `comp` matching `predicate`.""" - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - count = 0 - - def fn(comp): - nonlocal count - if predicate is None or predicate(comp): - count += 1 - return comp, False - - transformation_utils.transform_postorder(comp, fn) - return count - - -def _get_number_of_nodes_via_transform_postorder_with_symbol_bindings( - comp, predicate=None -): - """Returns the number of nodes in `comp` matching `predicate`.""" - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - count = 0 - - def fn(comp, ctxt_tree): - nonlocal count - del ctxt_tree - if predicate is None or predicate(comp): - count += 1 - return comp, False - - transformation_utils.transform_postorder_with_symbol_bindings( - comp, fn, empty_context_tree - ) - - return count - - -def _get_number_of_nodes_via_transform_preorder(comp, predicate=None): - """Returns the number of nodes in `comp` matching `predicate`.""" - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - count = 0 - - def fn(comp): - nonlocal count - if predicate is None or predicate(comp): - count += 1 - return comp, False - - transformation_utils.transform_preorder(comp, fn) - return count - - -class TransformationUtilsTest(parameterized.TestCase): - - def test_transform_postorder_fails_on_none_comp(self): - def transform(comp): - return comp, False - - with self.assertRaises(TypeError): - transformation_utils.transform_postorder(None, transform) - - def test_transform_postorder_fails_on_none_transform(self): - comp = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - transformation_utils.transform_postorder(comp, None) - - def test_transform_postorder_with_lambda_call_selection_and_reference(self): - function_type = computation_types.FunctionType(np.int32, np.int32) - ref = building_blocks.Reference('FEDERATED_arg', [function_type, np.int32]) - fn = building_blocks.Selection(ref, index=0) - arg = building_blocks.Selection(ref, index=1) - call = building_blocks.Call(fn, arg) - comp = building_blocks.Lambda(ref.name, np.int32, call) - self.assertEqual( - str(comp), '(FEDERATED_arg -> FEDERATED_arg[0](FEDERATED_arg[1]))' - ) - - def _transformation_fn_generator(): - n = 0 - while True: - n = n + 1 - - def _fn(x): - intrinsic_type = computation_types.FunctionType( - x.type_signature, x.type_signature - ) - intrinsic = building_blocks.Intrinsic('F{}'.format(n), intrinsic_type) - call = building_blocks.Call(intrinsic, x) - return call, True - - yield _fn - - transformation_fn_sequence = _transformation_fn_generator() - - def tx_fn(x): - return next(transformation_fn_sequence)(x) - - transfomed_comp, modified = transformation_utils.transform_postorder( - comp, tx_fn - ) - self.assertEqual( - transfomed_comp.compact_representation(), - ( - 'F6((FEDERATED_arg ->' - ' F5(F2(F1(FEDERATED_arg)[0])(F4(F3(FEDERATED_arg)[1])))))' - ), - ) - self.assertTrue(modified) - - def test_transform_postorder_with_block_and_data_to_reference(self): - ref = building_blocks.Reference('x', np.int32) - data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - blk = building_blocks.Block([('x', data)], ref) - - def _transformation_fn(comp): - if isinstance(comp, building_blocks.Block): - return building_blocks.Block(comp.locals, data), True - return comp, False - - transformed, modified = transformation_utils.transform_preorder( - blk, _transformation_fn - ) - self.assertTrue(modified) - self.assertEqual(transformed.compact_representation(), '(let x=1 in 1)') - - @parameterized.named_parameters( - _construct_trivial_instance_of_all_computation_building_blocks() - + [( - 'complex_tree', - building_block_test_utils.create_nested_syntax_tree(), - )] - ) - def test_transform_postorder_returns_untransformed(self, comp): - def transform_noop(comp): - return comp, False - - same_comp, modified = transformation_utils.transform_postorder( - comp, transform_noop - ) - self.assertEqual( - same_comp.compact_representation(), comp.compact_representation() - ) - self.assertFalse(modified) - - @parameterized.named_parameters( - _construct_trivial_instance_of_all_computation_building_blocks() - ) - def test_transform_postorder_does_not_construct_new_internal(self, comp): - def transform_noop(comp): - return comp, False - - same_comp, modified = transformation_utils.transform_postorder( - comp, transform_noop - ) - - self.assertEqual(comp, same_comp) - self.assertFalse(modified) - - def test_transform_postorder_hits_all_nodes_once(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - self.assertEqual( - _get_number_of_nodes_via_transform_postorder(complex_ast), 22 - ) - - def test_transform_postorder_walks_to_leaves_in_postorder(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - leaf_name_order = [] - - def transform(comp): - if isinstance(comp, building_blocks.Literal): - leaf_name_order.append(comp.value) - return comp, False - - transformation_utils.transform_postorder(complex_ast, transform) - - self.assertEqual(leaf_name_order, [1, 2, 3, 4, 5, 6, 7, 10, 8, 9, 11]) - - def test_transform_postorder_walks_block_locals_postorder(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - leaf_name_order = [] - - def transform(comp): - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - leaf_name_order.append(name) - return comp, False - - transformation_utils.transform_postorder(complex_ast, transform) - - self.assertEqual(leaf_name_order, ['t', 'u', 'v', 'w', 'x', 'y', 'z']) - - def test_transform_postorder_walks_through_all_internal_nodes_postorder(self): - """Checks `transform_postorder` walks correctly through any internal node. - - This test is split from the one above because it tests extra cases - in `transform_postorder`; in particular, all instances of - `building_blocks.ComputationBuildingBlock` which kick off - recursive calls of `transform_postorder` are exercised in this test, - while only a subset are exercised in the above. For example, if the - logic ingesting a `Call` breaks, this test will fail and the one above - may pass. - """ - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - leaf_name_order = [] - - def transform(comp): - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - leaf_name_order.append(name) - elif isinstance(comp, building_blocks.Literal): - leaf_name_order.append(comp.value) - return comp, False - - transformation_utils.transform_postorder(complex_ast, transform) - postorder_nodes = [ - 1, - 2, - 3, - 4, - 't', - 5, - 6, - 'u', - 7, - 'v', - 10, - 8, - 9, - 'w', - 'x', - 'y', - 'z', - 11, - ] - - self.assertEqual(leaf_name_order, list(postorder_nodes)) - - # TODO: b/113123410 - Add more tests for corner cases of `transform_preorder`. - - def test_transform_postorder_with_symbol_bindings_fails_on_none_comp(self): - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - - def transform(comp, ctxt_tree): - del ctxt_tree - return comp, False - - with self.assertRaises(TypeError): - transformation_utils.transform_postorder_with_symbol_bindings( - None, transform, empty_context_tree - ) - - def test_transform_postorder_with_symbol_bindings_fails_on_none_transform( - self, - ): - empty_symbol_tree = transformation_utils.SymbolTree(FakeTracker) - whimsy_comp = building_blocks.Reference('x', np.int32) - - with self.assertRaises(TypeError): - transformation_utils.transform_postorder_with_symbol_bindings( - whimsy_comp, None, empty_symbol_tree - ) - - def test_transform_postorder_with_symbol_bindings_fails_on_none_symbol_tree( - self, - ): - whimsy_comp = building_blocks.Reference('x', np.int32) - - def transform(comp, ctxt_tree): - del ctxt_tree - return comp, False - - with self.assertRaises(TypeError): - transformation_utils.transform_postorder_with_symbol_bindings( - whimsy_comp, transform, None - ) - - @parameterized.named_parameters( - _construct_trivial_instance_of_all_computation_building_blocks() - + [('complex_ast', building_block_test_utils.create_nested_syntax_tree())] - ) - def test_transform_postorder_with_symbol_bindings_returns_untransformed( - self, comp - ): - def transform_noop(comp, ctxt_tree): - del ctxt_tree - return comp, False - - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - same_comp, _ = ( - transformation_utils.transform_postorder_with_symbol_bindings( - comp, transform_noop, empty_context_tree - ) - ) - self.assertEqual( - same_comp.compact_representation(), comp.compact_representation() - ) - - @parameterized.named_parameters( - _construct_trivial_instance_of_all_computation_building_blocks() - ) - def test_transform_postorder_with_symbol_bindings_does_not_constructs_new_internal_nodes( - self, comp - ): - def transform_noop(comp, ctxt_tree): - del ctxt_tree - return comp, False - - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - same_comp, _ = ( - transformation_utils.transform_postorder_with_symbol_bindings( - comp, transform_noop, empty_context_tree - ) - ) - if not isinstance( - comp, - ( - building_blocks.CompiledComputation, - building_blocks.Data, - building_blocks.Intrinsic, - building_blocks.Literal, - building_blocks.Placement, - building_blocks.Reference, - ), - ): - self.assertEqual(id(comp), id(same_comp)) - - def test_transform_postorder_with_symbol_bindings_hits_all_nodes_once(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - simple_count = _get_number_of_nodes_via_transform_postorder(complex_ast) - with_hooks_count = ( - _get_number_of_nodes_via_transform_postorder_with_symbol_bindings( - complex_ast - ) - ) - - self.assertEqual(with_hooks_count, simple_count) - - @parameterized.named_parameters( - ('reference', building_blocks.Reference), - ('lambda', building_blocks.Lambda), - ('block', building_blocks.Block), - ('data', building_blocks.Data), - ('intrinsic', building_blocks.Intrinsic), - ('struct', building_blocks.Struct), - ('selection', building_blocks.Selection), - ('call', building_blocks.Call), - ('compiled_computation', building_blocks.CompiledComputation), - ('placement', building_blocks.Placement), - ) - def test_transform_postorder_with_symbol_bindings_counts_each_type_correctly( - self, cbb_type - ): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - simple_count = _get_number_of_nodes_via_transform_postorder( - complex_ast, predicate=lambda x: isinstance(x, cbb_type) - ) - with_hooks_count = ( - _get_number_of_nodes_via_transform_postorder_with_symbol_bindings( - complex_ast, predicate=lambda x: isinstance(x, cbb_type) - ) - ) - - self.assertEqual(with_hooks_count, simple_count) - - def test_transform_postorder_hooks_walks_leaves_in_postorder(self): - leaf_order = [] - outer_comp = building_block_test_utils.create_nested_syntax_tree() - - def transform(comp, ctxt_tree): - del ctxt_tree - if isinstance(comp, building_blocks.Literal): - leaf_order.append(comp.value) - return comp, False - - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - transformation_utils.transform_postorder_with_symbol_bindings( - outer_comp, transform, empty_context_tree - ) - self.assertEqual(leaf_order, [1, 2, 3, 4, 5, 6, 7, 10, 8, 9, 11]) - - def test_transform_postorder_hooks_walks_block_locals_postorder(self): - block_locals_order = [] - outer_comp = building_block_test_utils.create_nested_syntax_tree() - - def transform(comp, ctxt_tree): - del ctxt_tree - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - block_locals_order.append(name) - return comp, False - - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - transformation_utils.transform_postorder_with_symbol_bindings( - outer_comp, transform, empty_context_tree - ) - self.assertEqual(block_locals_order, ['t', 'u', 'v', 'w', 'x', 'y', 'z']) - - def test_transform_postorder_hooks_walks_variable_declarations_in_order(self): - variable_binding_order = [] - outer_comp = building_block_test_utils.create_nested_syntax_tree() - - class PreorderHookTracker(transformation_utils.BoundVariableTracker): - - def __init__(self, name, value): - variable_binding_order.append(name) - super().__init__(name, value) - - def update(self, value): - pass - - def __str__(self): - pass - - def __eq__(self, other): - return NotImplemented - - empty_context_tree = transformation_utils.SymbolTree(PreorderHookTracker) - transformation_utils.transform_postorder_with_symbol_bindings( - outer_comp, lambda x, y: (x, False), empty_context_tree - ) - self.assertEqual( - variable_binding_order, ['arg', 'y', 'z', 't', 'u', 'v', 'x', 'w'] - ) - - def test_transform_postorder_hooks_walks_postorder_interleaved(self): - named_node_order = [] - outer_comp = building_block_test_utils.create_nested_syntax_tree() - - def transform(comp, ctxt_tree): - del ctxt_tree - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - named_node_order.append(name) - elif isinstance(comp, building_blocks.Literal): - named_node_order.append(comp.value) - return comp, False - - empty_context_tree = transformation_utils.SymbolTree(FakeTracker) - transformation_utils.transform_postorder_with_symbol_bindings( - outer_comp, transform, empty_context_tree - ) - correct_results = [ - 1, - 2, - 3, - 4, - 't', - 5, - 6, - 'u', - 7, - 'v', - 10, - 8, - 9, - 'w', - 'x', - 'y', - 'z', - 11, - ] - self.assertEqual(named_node_order, correct_results) - - def test_transform_postorder_with_symbol_bindings_binds_lambda_param(self): - result = building_blocks.Reference('x', np.int32) - lam = building_blocks.Lambda('x', np.int32, result) - empty_symbol_tree = transformation_utils.SymbolTree(UpdatableTracker) - value_holder = [] - - def transform(comp, ctxt_tree): - if isinstance(comp, building_blocks.Reference): - ctxt_tree.update_payload_with_name(comp.name) - value_holder.append(ctxt_tree.get_payload_with_name(comp.name)) - return comp, False - - transformation_utils.transform_postorder_with_symbol_bindings( - lam, transform, empty_symbol_tree - ) - - self.assertEqual(value_holder[0].count, 1) - self.assertEqual(value_holder[0].name, 'x') - self.assertIsNone(value_holder[0].value) - - def test_transform_postorder_with_symbol_bindings_binds_single_block_local( - self, - ): - result = building_blocks.Reference('x', np.int32) - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block = building_blocks.Block([('x', arg)], result) - empty_symbol_tree = transformation_utils.SymbolTree(UpdatableTracker) - value_holder = [] - - def transform(comp, ctxt_tree): - if isinstance(comp, building_blocks.Reference): - ctxt_tree.update_payload_with_name(comp.name) - value_holder.append(ctxt_tree.get_payload_with_name(comp.name)) - return comp, False - - transformation_utils.transform_postorder_with_symbol_bindings( - block, transform, empty_symbol_tree - ) - - self.assertEqual(value_holder[0].count, 1) - self.assertEqual(value_holder[0].name, 'x') - self.assertEqual(value_holder[0].value, arg) - - def test_transform_postorder_with_symbol_bindings_binds_sequential_block_locals( - self, - ): - result = building_blocks.Reference('x', np.int32) - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - arg2 = building_blocks.Reference('x', np.int32) - block = building_blocks.Block([('x', arg), ('x', arg2)], result) - empty_symbol_tree = transformation_utils.SymbolTree(UpdatableTracker) - value_holder = [] - - def transform(comp, ctxt_tree): - if isinstance(comp, building_blocks.Reference): - ctxt_tree.update_payload_with_name(comp.name) - value_holder.append(ctxt_tree.get_payload_with_name(comp.name)) - return comp, False - - transformation_utils.transform_postorder_with_symbol_bindings( - block, transform, empty_symbol_tree - ) - - self.assertEqual(value_holder[0].count, 1) - self.assertEqual(value_holder[0].name, 'x') - self.assertEqual(value_holder[0].value, arg) - self.assertEqual(value_holder[1].count, 1) - self.assertEqual(value_holder[1].name, 'x') - self.assertEqual(value_holder[1].value, arg2) - - def test_symbol_tree_initializes(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - self.assertIsInstance( - symbol_tree.active_node.payload, transformation_utils._BeginScopePointer - ) - - def test_symbol_tree_node_reuse_fails(self): - fake_tracker_node_one = transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ) - fake_tracker_node_two = transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ) - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_child(0, fake_tracker_node_one) - symbol_tree._move_to_child(0) - symbol_tree._add_younger_sibling(fake_tracker_node_two) - symbol_tree.walk_down_one_variable_binding() - with self.assertRaisesRegex(ValueError, 'can only appear once'): - symbol_tree._add_child(1, fake_tracker_node_one) - with self.assertRaisesRegex(ValueError, 'can only appear once'): - symbol_tree._add_younger_sibling(fake_tracker_node_one) - - def test_symbol_tree_get_payload_resolves_child_parent_name_conflict(self): - def _construct_symbol_tree(): - """Constructs a symbol tree of the form below. - - Outer Context - | - V - x_tracker - | - V - x_tracker2* - - Returns: - Returns this tree and the payloads used to construct it. - """ - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - x_tracker = FakeTracker('x', None) - symbol_tree._add_child( - 0, transformation_utils.SequentialBindingNode(x_tracker) - ) - symbol_tree._move_to_child(0) - x_tracker2 = FakeTracker('x', None) - symbol_tree._add_child( - 1, transformation_utils.SequentialBindingNode(x_tracker2) - ) - symbol_tree._move_to_child(1) - return symbol_tree, x_tracker, x_tracker2 - - symbol_tree, _, x_tracker2 = _construct_symbol_tree() - self.assertEqual(id(symbol_tree.get_payload_with_name('x')), id(x_tracker2)) - - def test_symbol_tree_get_payload_resolves_sibling_name_conflict(self): - def _construct_symbol_tree(): - """Constructs a symbol tree of the form below. - - Outer Context - | - V - x_tracker - | - x_tracker2* - - Returns: - Returns this tree and the payloads used to construct it. - """ - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - x_tracker = FakeTracker('x', None) - symbol_tree._add_child( - 0, transformation_utils.SequentialBindingNode(x_tracker) - ) - symbol_tree._move_to_child(0) - x_tracker2 = FakeTracker('x', None) - symbol_tree._add_younger_sibling( - transformation_utils.SequentialBindingNode(x_tracker2) - ) - symbol_tree.walk_down_one_variable_binding() - return symbol_tree, x_tracker, x_tracker2 - - symbol_tree, _, x_tracker2 = _construct_symbol_tree() - self.assertEqual(id(symbol_tree.get_payload_with_name('x')), id(x_tracker2)) - - def test_symbol_tree_get_payload_addresses_parent(self): - def _construct_symbol_tree(): - """Constructs a symbol tree of the form below. - - Outer Context - | - V - z_tracker - | - V - x_tracker* - - Returns: - Returns this tree and the payloads used to construct it. - """ - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - z_tracker = FakeTracker('z', None) - symbol_tree._add_child( - 0, transformation_utils.SequentialBindingNode(z_tracker) - ) - symbol_tree._move_to_child(0) - x_tracker = FakeTracker('x', None) - symbol_tree._add_child( - 1, transformation_utils.SequentialBindingNode(x_tracker) - ) - symbol_tree._move_to_child(1) - return symbol_tree, z_tracker, x_tracker - - symbol_tree, z_tracker, _ = _construct_symbol_tree() - self.assertEqual(id(symbol_tree.get_payload_with_name('z')), id(z_tracker)) - - def test_symbol_tree_updates_correct_node_across_siblings(self): - def _construct_symbol_tree(): - r"""Builds symbol tree with the structure below. - - Outer Context - | - V - x_tracker - | - elder_y - | - young_y* - - Returns: - Returns this tree and the `SequentialBindingNode`s - used to construct it. - """ - x_tracker = transformation_utils.SequentialBindingNode( - UpdatableTracker('x', None) - ) - elder_y = transformation_utils.SequentialBindingNode( - UpdatableTracker('y', None) - ) - young_y = transformation_utils.SequentialBindingNode( - UpdatableTracker('y', None) - ) - - complex_symbol_tree = transformation_utils.SymbolTree(UpdatableTracker) - complex_symbol_tree._add_child(4, x_tracker) - complex_symbol_tree._move_to_child(4) - complex_symbol_tree._add_younger_sibling(elder_y) - complex_symbol_tree.walk_down_one_variable_binding() - complex_symbol_tree._add_younger_sibling(young_y) - complex_symbol_tree.walk_down_one_variable_binding() - return complex_symbol_tree, x_tracker, elder_y, young_y - - (complex_symbol_tree, x_tracker, elder_y, young_y) = ( - _construct_symbol_tree() - ) - complex_symbol_tree.update_payload_with_name('x') - complex_symbol_tree.update_payload_with_name('y') - self.assertEqual(x_tracker.payload.count, 1) - self.assertEqual(young_y.payload.count, 1) - self.assertEqual(complex_symbol_tree.get_payload_with_name('x').count, 1) - self.assertEqual(complex_symbol_tree.get_payload_with_name('y').count, 1) - self.assertEqual(elder_y.payload.count, 0) - self.assertIsNone(complex_symbol_tree.get_payload_with_name('z')) - - def test_symbol_tree_updates_correct_node_across_generations(self): - def _construct_symbol_tree(): - r"""Builds symbol tree with the structure below. - - Outer Context - | - V - x_tracker - | - elder_y - / \ - V V - young_y* misdirect_z - - Returns: - Returns this tree and the `SequentialBindingNode`s - used to construct it. - """ - x_tracker = transformation_utils.SequentialBindingNode( - UpdatableTracker('x', None) - ) - elder_y = transformation_utils.SequentialBindingNode( - UpdatableTracker('y', None) - ) - young_y = transformation_utils.SequentialBindingNode( - UpdatableTracker('y', None) - ) - misdirect_z = transformation_utils.SequentialBindingNode( - UpdatableTracker('z', None) - ) - - complex_symbol_tree = transformation_utils.SymbolTree(UpdatableTracker) - complex_symbol_tree.drop_scope_down(4) - complex_symbol_tree._add_younger_sibling(x_tracker) - complex_symbol_tree.walk_down_one_variable_binding() - complex_symbol_tree._add_younger_sibling(elder_y) - complex_symbol_tree.walk_down_one_variable_binding() - complex_symbol_tree.drop_scope_down(5) - complex_symbol_tree._add_younger_sibling(young_y) - complex_symbol_tree.walk_down_one_variable_binding() - complex_symbol_tree.drop_scope_down(6) - complex_symbol_tree._add_younger_sibling(misdirect_z) - complex_symbol_tree.pop_scope_up() - return (complex_symbol_tree, x_tracker, elder_y, young_y, misdirect_z) - - (complex_symbol_tree, x_tracker, elder_y, young_y, misdirect_z) = ( - _construct_symbol_tree() - ) - complex_symbol_tree.update_payload_with_name('x') - complex_symbol_tree.update_payload_with_name('y') - self.assertEqual(x_tracker.payload.count, 1) - self.assertEqual(young_y.payload.count, 1) - self.assertEqual(elder_y.payload.count, 0) - self.assertEqual(complex_symbol_tree.get_payload_with_name('x').count, 1) - self.assertEqual(complex_symbol_tree.get_payload_with_name('y').count, 1) - self.assertIsNone(complex_symbol_tree.get_payload_with_name('z')) - complex_symbol_tree.pop_scope_up() - complex_symbol_tree.update_payload_with_name('y') - complex_symbol_tree.update_payload_with_name('y') - self.assertEqual(elder_y.payload.count, 2) - self.assertEqual(complex_symbol_tree.get_payload_with_name('y').count, 2) - self.assertEqual(misdirect_z.payload.count, 0) - complex_symbol_tree.walk_to_scope_beginning() - self.assertIsNone(complex_symbol_tree.get_payload_with_name('y')) - - def test_typechecking_in_symbol_tree_resolve_methods(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - with self.assertRaises(TypeError): - symbol_tree.get_payload_with_name(0) - with self.assertRaises(TypeError): - symbol_tree.update_payload_with_name(0) - with self.assertRaises(ValueError): - symbol_tree.update_payload_with_name('x') - - def test_symbol_tree_walk_to_scope_beginning_nonempty_scope_moves_to_sentinel( - self, - ): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.ingest_variable_binding('x', None) - fake_tracker_payload = symbol_tree.active_node.payload - symbol_tree.walk_to_scope_beginning() - self.assertIsInstance(fake_tracker_payload, FakeTracker) - self.assertIsInstance( - symbol_tree.active_node.payload, transformation_utils._BeginScopePointer - ) - - def test_symbol_tree_walk_to_scope_beginning_empty_scope_noops(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - begin_scope_node = symbol_tree.active_node - symbol_tree.walk_to_scope_beginning() - self.assertIs(symbol_tree.active_node, begin_scope_node) - - def test_symbol_tree_pop_scope_up_at_top_level_fails(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_child( - 0, - transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ), - ) - with self.assertRaisesRegex(ValueError, 'highest level'): - symbol_tree.pop_scope_up() - - def test_symbol_tree_pop_scope_up_one_level_tree_succeeds(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.ingest_variable_binding('x', None) - symbol_tree.drop_scope_down(0) - symbol_tree.pop_scope_up() - self.assertIsInstance(symbol_tree.active_node.payload, FakeTracker) - - def test_symbol_tree_drop_scope_down_fails_bad_type(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - with self.assertRaises(TypeError): - symbol_tree.drop_scope_down('a') - - def test_symbol_tree_drop_scope_down_moves_to_sentinel(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.ingest_variable_binding('x', None) - symbol_tree.drop_scope_down(0) - self.assertIsInstance( - symbol_tree.active_node.payload, transformation_utils._BeginScopePointer - ) - - def test_symbol_tree_drop_scope_down_equivalent_to_add_child_and_move(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - shadow_symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.drop_scope_down(0) - shadow_symbol_tree._add_child( - 0, - transformation_utils.SequentialBindingNode( - transformation_utils._BeginScopePointer() - ), - ) - shadow_symbol_tree._move_to_child(0) - self.assertEqual(symbol_tree, shadow_symbol_tree) - - def test_symbol_tree_walk_down_bad_variable_binding_fails(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - with self.assertRaisesRegex(ValueError, 'nonexistent variable binding'): - symbol_tree.walk_down_one_variable_binding() - - def test_symbol_tree_walk_down_good_variable_binding_moves_active_node(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - self.assertIsInstance(symbol_tree.active_node.payload, FakeTracker) - - def test_symbol_tree_walk_down_good_variable_binding_moves_to_bound_variable( - self, - ): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.ingest_variable_binding('x', None) - symbol_tree.walk_to_scope_beginning() - symbol_tree.walk_down_one_variable_binding() - self.assertEqual(symbol_tree.get_payload_with_name('x').name, 'x') - self.assertIsNone(symbol_tree.get_payload_with_name('x').value) - - def test_symbol_tree_ingest_variable_binding_bad_args_fails(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - with self.assertRaises(TypeError): - symbol_tree.ingest_variable_binding( - 0, building_blocks.Reference('x', np.int32) - ) - with self.assertRaises(TypeError): - symbol_tree.ingest_variable_binding('x', 0) - - def test_drop_scope_down_and_ingest_variable_binding_adds_node_to_empty_tree( - self, - ): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - shadow_symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.drop_scope_down(0) - symbol_tree.ingest_variable_binding( - 'x', building_blocks.Reference('x', np.int32) - ) - shadow_symbol_tree._add_child( - 0, - transformation_utils.SequentialBindingNode( - transformation_utils._BeginScopePointer() - ), - ) - shadow_symbol_tree._move_to_child(0) - shadow_symbol_tree._add_younger_sibling( - transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', building_blocks.Reference('x', np.int32)) - ) - ) - self.assertEqual(symbol_tree, shadow_symbol_tree) - - def test_ingest_variable_binding_adds_node_to_empty_tree(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - shadow_symbol_tree = transformation_utils.SymbolTree(FakeTracker) - payload_to_add = FakeTracker( - 'x', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - shadow_symbol_tree._add_younger_sibling( - transformation_utils.SequentialBindingNode(payload_to_add) - ) - - symbol_tree.ingest_variable_binding( - payload_to_add.name, payload_to_add.value - ) - - self.assertEqual(symbol_tree, shadow_symbol_tree) - - def test_ingest_variable_binding_adds_node_to_nonempty_tree(self): - symbol_tree = _construct_complex_symbol_tree() - shadow_symbol_tree = _construct_complex_symbol_tree() - payload_to_add = FakeTracker( - 'x', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - shadow_symbol_tree._add_younger_sibling( - transformation_utils.SequentialBindingNode(payload_to_add) - ) - - symbol_tree.ingest_variable_binding( - 'x', building_blocks.Reference('a', np.int32) - ) - - self.assertEqual(symbol_tree, shadow_symbol_tree) - - def test_ingest_variable_overwrites_existing_node_with_same_name(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.drop_scope_down(1) - symbol_tree.ingest_variable_binding( - 'y', building_blocks.Literal(2, computation_types.TensorType(np.int32)) - ) - resolved_y = symbol_tree.get_payload_with_name('y') - self.assertEqual(resolved_y.value.value, 2) - self.assertEqual(str(resolved_y.value.type_signature), 'int32') - symbol_tree.walk_to_scope_beginning() - symbol_tree.ingest_variable_binding( - 'y', - building_blocks.Literal(3.0, computation_types.TensorType(np.float32)), - ) - changed_y = symbol_tree.get_payload_with_name('y') - self.assertEqual(changed_y.value.value, 3.0) - self.assertEqual(str(changed_y.value.type_signature), 'float32') - - def test_ingest_variable_overwrite_leaves_unrelated_node_alone(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree.drop_scope_down(0) - symbol_tree.ingest_variable_binding( - 'x', - building_blocks.Literal(3.0, computation_types.TensorType(np.float32)), - ) - symbol_tree.drop_scope_down(1) - symbol_tree.ingest_variable_binding( - 'y', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - resolved_x = symbol_tree.get_payload_with_name('x') - self.assertEqual(resolved_x.value.value, 3.0) - self.assertEqual(str(resolved_x.value.type_signature), 'float32') - symbol_tree.pop_scope_up() - symbol_tree.drop_scope_down(1) - symbol_tree.ingest_variable_binding( - 'y', - building_blocks.Literal(4.0, computation_types.TensorType(np.float32)), - ) - same_x = symbol_tree.get_payload_with_name('x') - self.assertEqual(same_x.value, resolved_x.value) - - def test_ingest_variable_raises_error_on_name_conflict(self): - symbol_tree = _construct_complex_symbol_tree() - symbol_tree.drop_scope_down(0) - symbol_tree.ingest_variable_binding( - 'x', - building_blocks.Literal(3.0, computation_types.TensorType(np.float32)), - ) - symbol_tree.drop_scope_down(1) - symbol_tree.ingest_variable_binding( - 'y', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - symbol_tree.pop_scope_up() - symbol_tree.drop_scope_down(1) - with self.assertRaises(ValueError): - symbol_tree.ingest_variable_binding( - 'z', - building_blocks.Literal(2, computation_types.TensorType(np.int32)), - ) - - def test_symbol_tree_add_sibling(self): - fake_node = transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ) - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_younger_sibling(fake_node) - symbol_tree.walk_down_one_variable_binding() - self.assertEqual(id(symbol_tree.active_node), id(fake_node)) - self.assertIsNone(symbol_tree.active_node.children.get(0)) - self.assertIsNone(symbol_tree.active_node.younger_sibling) - symbol_tree.walk_to_scope_beginning() - self.assertEqual( - symbol_tree.active_node.payload, - transformation_utils._BeginScopePointer(), - ) - self.assertIsNotNone(symbol_tree.active_node.younger_sibling) - self.assertIsNone(symbol_tree.active_node.children.get(0)) - - def test_symbol_tree_has_younger_sibling(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - self.assertIsNotNone(symbol_tree.active_node.younger_sibling) - - def test_symbol_tree_add_child(self): - fake_node = transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ) - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_child(0, fake_node) - symbol_tree._move_to_child(0) - self.assertEqual(id(symbol_tree.active_node), id(fake_node)) - symbol_tree.active_node = symbol_tree.active_node.parent - self.assertEqual( - symbol_tree.active_node.payload, - transformation_utils._BeginScopePointer(), - ) - - def test_symbol_tree_move_to_bad_child_fails(self): - fake_node = transformation_utils.SequentialBindingNode( - FakeTracker('FakeTracker', None) - ) - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - symbol_tree._add_child(0, fake_node) - with self.assertRaises(ValueError): - symbol_tree._move_to_child(1) - - def test_complicated_symbol_tree_equality(self): - first_tree = _construct_complex_symbol_tree() - second_tree = _construct_complex_symbol_tree() - self.assertEqual(first_tree, second_tree) - second_tree._add_child( - 10, - transformation_utils.SequentialBindingNode(FakeTracker('alpha', None)), - ) - self.assertNotEqual(first_tree, second_tree) - self.assertNotEqual(second_tree, first_tree) - - def test_complicated_symbol_tree_equality_independent_of_active_node(self): - first_tree = _construct_complex_symbol_tree() - second_tree = _construct_complex_symbol_tree() - second_tree.pop_scope_up() - self.assertEqual(first_tree, second_tree) - - def test_complicated_symbol_tree_resolves_string_correctly(self): - symbol_tree = transformation_utils.SymbolTree(FakeTracker) - for _ in range(2): - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - symbol_tree.drop_scope_down(0) - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - for _ in range(2): - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - symbol_tree.pop_scope_up() - symbol_tree.drop_scope_down(1) - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - for k in range(2): - symbol_tree.drop_scope_down(k + 2) - symbol_tree._add_younger_sibling(fake_tracker_node_factory()) - symbol_tree.walk_down_one_variable_binding() - symbol_tree.pop_scope_up() - - self.assertEqual( - str(symbol_tree), - '[BeginScope]-[FakeTracker]-[FakeTracker]->{([BeginScope]-[FakeTracker]-[FakeTracker]-[FakeTracker]),(([BeginScope]-[FakeTracker*]->{([BeginScope]-[FakeTracker]),(([BeginScope]-[FakeTracker])})}', - ) - symbol_tree.pop_scope_up() - self.assertEqual( - str(symbol_tree), - '[BeginScope]-[FakeTracker]-[FakeTracker*]->{([BeginScope]-[FakeTracker]-[FakeTracker]-[FakeTracker]),(([BeginScope]-[FakeTracker]->{([BeginScope]-[FakeTracker]),(([BeginScope]-[FakeTracker])})}', - ) - - def test_trivial_subclass_init_fails_bad_args(self): - with self.assertRaises(TypeError): - TrivialBoundVariableTracker() - with self.assertRaises(TypeError): - TrivialBoundVariableTracker(0, None) - with self.assertRaises(TypeError): - TrivialBoundVariableTracker('x', 0) - - def test_trivial_subclass_init(self): - x = TrivialBoundVariableTracker('x', None) - self.assertEqual(x.name, 'x') - self.assertIsNone(x.value) - - def test_sequential_binding_node_fails_bad_args(self): - with self.assertRaises(TypeError): - transformation_utils.SequentialBindingNode(None) - with self.assertRaises(TypeError): - transformation_utils.SequentialBindingNode(0) - - def test_sequential_binding_node_initialization(self): - trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('trivial_name', None) - ) - - self.assertEqual(trivial_instance.payload.name, 'trivial_name') - self.assertEmpty(trivial_instance.children) - self.assertIsNone(trivial_instance.payload.value) - self.assertIsNone(trivial_instance.parent) - self.assertIsNone(trivial_instance.younger_sibling) - self.assertIsNone(trivial_instance.older_sibling) - - def test_bound_variable_tracker_trivial_subclass_init_bad_args(self): - with self.assertRaises(TypeError): - TrivialBoundVariableTracker(0, None) - with self.assertRaises(TypeError): - TrivialBoundVariableTracker('x', 0) - - def test_sequential_binding_node_parent_child_relationship(self): - trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('trivial_name', None) - ) - second_trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('second_trivial_name', None) - ) - - self.assertNotEqual(trivial_instance, second_trivial_instance) - second_trivial_instance.set_parent(trivial_instance) - trivial_instance.add_child(0, second_trivial_instance) - self.assertEqual(trivial_instance.get_child(0), second_trivial_instance) - self.assertIsNone(trivial_instance.get_child(1)) - self.assertEqual(second_trivial_instance.parent, trivial_instance) - with self.assertRaises(TypeError): - trivial_instance.set_parent(0) - with self.assertRaises(TypeError): - second_trivial_instance.add_child(0, 0) - - def test_sequential_binding_node_sibling_relationship(self): - trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('trivial_name', None) - ) - second_trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('second_trivial_name', None) - ) - - self.assertNotEqual(trivial_instance, second_trivial_instance) - trivial_instance.set_younger_sibling(second_trivial_instance) - self.assertEqual(trivial_instance.younger_sibling, second_trivial_instance) - second_trivial_instance.set_older_sibling(trivial_instance) - self.assertEqual(second_trivial_instance.older_sibling, trivial_instance) - with self.assertRaises(TypeError): - trivial_instance.set_younger_sibling(0) - with self.assertRaises(TypeError): - second_trivial_instance.set_older_sibling(0) - - def test_sequential_binding_nodes_cousin_relationship(self): - trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('trivial_name', None) - ) - second_trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('second_trivial_name', None) - ) - third_trivial_instance = transformation_utils.SequentialBindingNode( - TrivialBoundVariableTracker('third_trivial_name', None) - ) - trivial_instance.add_child(0, second_trivial_instance) - trivial_instance.add_child(1, third_trivial_instance) - second_trivial_instance.set_parent(trivial_instance) - third_trivial_instance.set_parent(trivial_instance) - second_trivial_instance_relations = [ - second_trivial_instance.parent, - second_trivial_instance.older_sibling, - second_trivial_instance.younger_sibling, - ] + list(second_trivial_instance.children.values()) - - third_trivial_instance_relations = [ - third_trivial_instance.parent, - third_trivial_instance.older_sibling, - third_trivial_instance.younger_sibling, - ] + list(third_trivial_instance.children.values()) - self.assertNotIn(second_trivial_instance, third_trivial_instance_relations) - self.assertNotIn(third_trivial_instance, second_trivial_instance_relations) - self.assertEqual( - id(second_trivial_instance.parent), id(third_trivial_instance.parent) - ) - - def test_bound_variable_tracker_equality_names(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - whimsy_tracker = TrivialBoundVariableTracker('x', lit) - second_whimsy_tracker = TrivialBoundVariableTracker('x', lit) - self.assertEqual(whimsy_tracker, second_whimsy_tracker) - second_whimsy_tracker.name = 'y' - self.assertNotEqual(whimsy_tracker, second_whimsy_tracker) - whimsy_tracker.name = 'y' - self.assertEqual(whimsy_tracker, second_whimsy_tracker) - - def test_bound_variable_tracker_equality_values(self): - whimsy_tracker = TrivialBoundVariableTracker( - 'x', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - second_whimsy_tracker = TrivialBoundVariableTracker( - 'x', building_blocks.Literal(2, computation_types.TensorType(np.int32)) - ) - self.assertNotEqual(whimsy_tracker, second_whimsy_tracker) - - def test_outer_context_pointer_equality(self): - outer_context = transformation_utils._BeginScopePointer() - other_outer_context = transformation_utils._BeginScopePointer() - self.assertNotEqual(id(outer_context), id(other_outer_context)) - self.assertEqual(str(outer_context), 'BeginScope') - self.assertEqual(outer_context, other_outer_context) - - def test_outer_context_pointer_cant_update(self): - outer_context = transformation_utils._BeginScopePointer() - with self.assertRaises(RuntimeError): - outer_context.update() - - def test_reference_tracker_initializes(self): - whimsy_tracker = transformation_utils.ReferenceCounter( - 'x', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - self.assertEqual(whimsy_tracker.name, 'x') - self.assertEqual(whimsy_tracker.value.compact_representation(), '1') - self.assertEqual(whimsy_tracker.count, 0) - - def test_reference_tracker_updates(self): - whimsy_tracker = transformation_utils.ReferenceCounter( - 'x', building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ) - for k in range(10): - whimsy_tracker.update() - self.assertEqual(whimsy_tracker.count, k + 1) - - def test_reference_tracker_equality_instances(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - whimsy_tracker = transformation_utils.ReferenceCounter('x', lit) - second_whimsy_tracker = transformation_utils.ReferenceCounter('x', lit) - self.assertEqual(whimsy_tracker, second_whimsy_tracker) - whimsy_tracker.update() - self.assertNotEqual(whimsy_tracker, second_whimsy_tracker) - second_whimsy_tracker.update() - self.assertEqual(whimsy_tracker, second_whimsy_tracker) - - def test_get_count_of_references_to_variables_simple_block(self): - simple_block = _construct_simple_block( - computation_types.TensorType(np.int32) - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - simple_block - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - arg = transformation_utils.ReferenceCounter( - simple_block.locals[0][0], simple_block.locals[0][1] - ) - arg.update() - constructed_context_stack.drop_scope_down(1) - constructed_context_stack._add_younger_sibling( - transformation_utils.SequentialBindingNode(arg) - ) - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_simple_block_two_references( - self, - ): - simple_block = _construct_simple_block( - computation_types.TensorType(np.int32) - ) - ref = simple_block.result - result = building_blocks.Struct([ref, ref]) - simple_block = building_blocks.Block(simple_block.locals, result) - self.assertEqual( - simple_block.compact_representation(), '(let x=1 in )' - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - simple_block - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.ingest_variable_binding( - simple_block.locals[0][0], simple_block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name(ref.name) - constructed_context_stack.update_payload_with_name(ref.name) - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_nested_blocks_conflicting_names( - self, - ): - first_block = _construct_simple_block( - computation_types.TensorType(np.int32) - ) - outer_block_output = building_blocks.Reference( - 'x', first_block.type_signature - ) - second_block = building_blocks.Block( - [('x', first_block)], outer_block_output - ) - - self.assertEqual( - second_block.compact_representation(), '(let x=(let x=1 in x) in x)' - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - second_block - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.drop_scope_down(2) - constructed_context_stack.ingest_variable_binding( - first_block.locals[0][0], first_block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name( - first_block.result.name - ) - constructed_context_stack.pop_scope_up() - constructed_context_stack.ingest_variable_binding( - second_block.locals[0][0], second_block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name( - second_block.result.name - ) - constructed_context_stack.pop_scope_up() - return constructed_context_stack - - self.assertEqual(str(context_stack), str(_construct_context_stack())) - - def test_get_count_of_references_to_variables_block_lambda_name_conflict( - self, - ): - innermost_x = building_blocks.Reference('x', np.int32) - inner_lambda = building_blocks.Lambda('x', np.int32, innermost_x) - second_x = building_blocks.Reference('x', np.int32) - called_lambda = building_blocks.Call(inner_lambda, second_x) - block_input = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) - ) - block = building_blocks.Block([('x', block_input)], called_lambda) - self.assertEqual(block.compact_representation(), '(let x=1 in (x -> x)(x))') - context_stack = transformation_utils.get_count_of_references_to_variables( - block - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(2) - constructed_context_stack.ingest_variable_binding( - block.locals[0][0], block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name(second_x.name) - constructed_context_stack.drop_scope_down(3) - constructed_context_stack.ingest_variable_binding( - inner_lambda.parameter_name, None - ) - constructed_context_stack.update_payload_with_name(innermost_x.name) - constructed_context_stack.pop_scope_up() - constructed_context_stack.pop_scope_up() - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_lambda_name_conflict(self): - inner_x = building_blocks.Reference('x', np.int32) - inner_lambda = building_blocks.Lambda('x', np.int32, inner_x) - outer_x = building_blocks.Reference('x', np.int32) - call = building_blocks.Call(inner_lambda, outer_x) - outer_lambda = building_blocks.Lambda('x', np.int32, call) - self.assertEqual( - outer_lambda.compact_representation(), '(x -> (x -> x)(x))' - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - outer_lambda - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(0) - constructed_context_stack.ingest_variable_binding( - outer_lambda.parameter_name, None - ) - constructed_context_stack.update_payload_with_name(outer_x.name) - constructed_context_stack.drop_scope_down(0) - constructed_context_stack.ingest_variable_binding( - inner_lambda.parameter_name, None - ) - constructed_context_stack.update_payload_with_name(inner_x.name) - constructed_context_stack.pop_scope_up() - constructed_context_stack.pop_scope_up() - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_block_local_overwriting_name_in_scope( - self, - ): - arg_comp = building_blocks.Reference('arg', [np.int32, np.int32]) - selected = building_blocks.Selection(arg_comp, index=0) - internal_arg = building_blocks.Reference('arg', np.int32) - block = building_blocks.Block([('arg', selected)], internal_arg) - lam = building_blocks.Lambda('arg', arg_comp.type_signature, block) - self.assertEqual( - lam.compact_representation(), '(arg -> (let arg=arg[0] in arg))' - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - lam - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.ingest_variable_binding( - lam.parameter_name, None - ) - constructed_context_stack.update_payload_with_name(arg_comp.name) - constructed_context_stack.drop_scope_down(2) - constructed_context_stack.ingest_variable_binding( - block.locals[0][0], block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name(internal_arg.name) - constructed_context_stack.pop_scope_up() - constructed_context_stack.pop_scope_up() - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_nested_block_no_name_conflict( - self, - ): - used1 = building_blocks.Reference('used1', np.int32) - used2 = building_blocks.Literal(2, computation_types.TensorType(np.int32)) - ref = building_blocks.Reference('x', used1.type_signature) - lower_block = building_blocks.Block([('x', used1)], ref) - higher_block = building_blocks.Block([('used1', used2)], lower_block) - self.assertEqual( - higher_block.compact_representation(), - '(let used1=2 in (let x=used1 in x))', - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - higher_block - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.ingest_variable_binding( - higher_block.locals[0][0], higher_block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name(used1.name) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.ingest_variable_binding( - lower_block.locals[0][0], lower_block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name(ref.name) - constructed_context_stack.pop_scope_up() - constructed_context_stack.pop_scope_up() - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_nested_block_no_overwrite(self): - used1 = building_blocks.Reference('used1', np.int32) - used2 = building_blocks.Literal(2, computation_types.TensorType(np.int32)) - user_inlined_lower_block = building_blocks.Block([('x', used1)], used1) - user_inlined_higher_block = building_blocks.Block( - [('used1', used2)], user_inlined_lower_block - ) - self.assertEqual( - user_inlined_higher_block.compact_representation(), - '(let used1=2 in (let x=used1 in used1))', - ) - second_context_stack = ( - transformation_utils.get_count_of_references_to_variables( - user_inlined_higher_block - ) - ) - - def _construct_second_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.ingest_variable_binding( - user_inlined_higher_block.locals[0][0], - user_inlined_higher_block.locals[0][1], - ) - constructed_context_stack.update_payload_with_name(used1.name) - constructed_context_stack.update_payload_with_name(used1.name) - constructed_context_stack.drop_scope_down(3) - constructed_context_stack.ingest_variable_binding( - user_inlined_lower_block.locals[0][0], - user_inlined_lower_block.locals[0][1], - ) - constructed_context_stack.pop_scope_up() - constructed_context_stack.pop_scope_up() - return constructed_context_stack - - self.assertEqual(second_context_stack, _construct_second_context_stack()) - - def test_get_count_of_references_to_variables_mixed_scope(self): - innermost = building_blocks.Reference('x', np.int32) - intermediate_arg = building_blocks.Reference('y', np.int32) - inner_block = building_blocks.Block([('x', intermediate_arg)], innermost) - item1 = building_blocks.Reference('x', np.int32) - mediate_tuple = building_blocks.Struct([item1, inner_block]) - used = building_blocks.Literal(0, computation_types.TensorType(np.int32)) - used1 = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - outer_block = building_blocks.Block( - [('x', used), ('y', used1)], mediate_tuple - ) - self.assertEqual( - outer_block.compact_representation(), - '(let x=0,y=1 in )', - ) - context_stack = transformation_utils.get_count_of_references_to_variables( - outer_block - ) - - def _construct_context_stack(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(1) - constructed_context_stack.ingest_variable_binding( - outer_block.locals[0][0], outer_block.locals[0][1] - ) - constructed_context_stack.ingest_variable_binding( - outer_block.locals[1][0], outer_block.locals[1][1] - ) - constructed_context_stack.update_payload_with_name(intermediate_arg.name) - constructed_context_stack.update_payload_with_name(item1.name) - constructed_context_stack.drop_scope_down(6) - constructed_context_stack.ingest_variable_binding( - inner_block.locals[0][0], inner_block.locals[0][1] - ) - constructed_context_stack.update_payload_with_name(innermost.name) - return constructed_context_stack - - self.assertEqual(context_stack, _construct_context_stack()) - - def test_get_count_of_references_to_variables_sequential_overwrite_in_block_locals( - self, - ): - tensor_type = computation_types.TensorType(np.int32) - proto = mock.create_autospec( - computation_pb2.Computation, spec_set=True, instance=True - ) - function_type = computation_types.FunctionType(None, tensor_type) - compiled = building_blocks.CompiledComputation( - proto, name='make_10', type_signature=function_type - ) - make_10 = building_blocks.Call(compiled, None) - - whimsy_x_reference = building_blocks.Reference('x', np.int32) - - make_13 = building_blocks.Block( - [ - ('x', make_10), - ('x', whimsy_x_reference), - ('x', whimsy_x_reference), - ('x', whimsy_x_reference), - ], - whimsy_x_reference, - ) - - references = transformation_utils.get_count_of_references_to_variables( - make_13 - ) - - child_id = list(references.active_node.children.keys())[0] - - def _make_context_tree(): - constructed_context_stack = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - constructed_context_stack.drop_scope_down(child_id) - constructed_context_stack.ingest_variable_binding( - make_13.locals[0][0], make_13.locals[0][1] - ) - constructed_context_stack.update_payload_with_name( - whimsy_x_reference.name - ) - constructed_context_stack.ingest_variable_binding( - make_13.locals[1][0], make_13.locals[1][1] - ) - constructed_context_stack.update_payload_with_name( - whimsy_x_reference.name - ) - constructed_context_stack.ingest_variable_binding( - make_13.locals[2][0], make_13.locals[2][1] - ) - constructed_context_stack.update_payload_with_name( - whimsy_x_reference.name - ) - constructed_context_stack.ingest_variable_binding( - make_13.locals[3][0], make_13.locals[3][1] - ) - constructed_context_stack.update_payload_with_name( - whimsy_x_reference.name - ) - constructed_context_stack.walk_to_scope_beginning() - return constructed_context_stack - - constructed_tree = _make_context_tree() - self.assertEqual(references, constructed_tree) - - -class TransformPreorderTest(parameterized.TestCase): - - def test_transform_preorder_fails_on_none_comp(self): - def transform(comp): - return comp, False - - with self.assertRaises(TypeError): - transformation_utils.transform_preorder(None, transform) - - def test_transform_preorder_fails_on_none_transform(self): - comp = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - transformation_utils.transform_preorder(comp, None) - - def test_transform_preorder_with_lambda_call_selection_and_reference(self): - function_type = computation_types.FunctionType(np.int32, np.int32) - ref = building_blocks.Reference('FEDERATED_arg', [function_type, np.int32]) - fn = building_blocks.Selection(ref, index=0) - arg = building_blocks.Selection(ref, index=1) - call = building_blocks.Call(fn, arg) - comp = building_blocks.Lambda(ref.name, np.int32, call) - self.assertEqual( - comp.compact_representation(), - '(FEDERATED_arg -> FEDERATED_arg[0](FEDERATED_arg[1]))', - ) - - def _transformation_fn_generator(): - n = 0 - while True: - n = n + 1 - - def _fn(x): - intrinsic_type = computation_types.FunctionType( - x.type_signature, x.type_signature - ) - intrinsic = building_blocks.Intrinsic('F{}'.format(n), intrinsic_type) - call = building_blocks.Call(intrinsic, x) - return call, True - - yield _fn - - transformation_fn_sequence = _transformation_fn_generator() - - def tx_fn(x): - return next(transformation_fn_sequence)(x) - - transfomed_comp, modified = transformation_utils.transform_preorder( - comp, tx_fn - ) - self.assertTrue(modified) - self.assertEqual( - transfomed_comp.compact_representation(), - 'F1((FEDERATED_arg -> FEDERATED_arg[0](FEDERATED_arg[1])))', - ) - self.assertTrue(modified) - - @parameterized.named_parameters( - _construct_trivial_instance_of_all_computation_building_blocks() - + [( - 'complex_tree', - building_block_test_utils.create_nested_syntax_tree(), - )] - ) - def test_transform_preorder_returns_untransformed(self, comp): - def transform_noop(comp): - return comp, False - - same_comp, modified = transformation_utils.transform_preorder( - comp, transform_noop - ) - self.assertEqual( - same_comp.compact_representation(), comp.compact_representation() - ) - self.assertFalse(modified) - - @parameterized.named_parameters( - _construct_trivial_instance_of_all_computation_building_blocks() - ) - def test_transform_preorder_does_not_construct_new_internal(self, comp): - def transform_noop(comp): - return comp, False - - same_comp, modified = transformation_utils.transform_preorder( - comp, transform_noop - ) - - self.assertEqual(comp, same_comp) - self.assertFalse(modified) - - def test_transform_preorder_hits_all_nodes_once(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - self.assertEqual( - _get_number_of_nodes_via_transform_preorder(complex_ast), 22 - ) - - def test_transform_preorder_walks_to_leaves_in_preorder(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - leaf_name_order = [] - - def transform(comp): - if isinstance(comp, building_blocks.Literal): - leaf_name_order.append(comp.value) - return comp, False - - transformation_utils.transform_preorder(complex_ast, transform) - - self.assertEqual(leaf_name_order, [1, 2, 3, 4, 5, 6, 7, 10, 8, 9, 11]) - - def test_transform_preorder_walks_block_locals_preorder(self): - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - leaf_name_order = [] - - def transform(comp): - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - leaf_name_order.append(name) - return comp, False - - transformation_utils.transform_preorder(complex_ast, transform) - - self.assertEqual(leaf_name_order, ['y', 'z', 'v', 't', 'u', 'x', 'w']) - - def test_transform_preorder_walks_through_all_internal_nodes_preorder(self): - """Checks `transform_preorder` walks correctly through any internal node. - - This test is split from the one above because it tests extra cases - in `transform_preorder`; in particular, all instances of - `building_blocks.ComputationBuildingBlock` which kick off - recursive calls of `transform_preorder` are exercised in this test, - while only a subset are exercised in the above. For example, if the - logic ingesting a `Call` breaks, this test will fail and the one above - may pass. - """ - complex_ast = building_block_test_utils.create_nested_syntax_tree() - - leaf_name_order = [] - - def transform(comp): - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - leaf_name_order.append(name) - elif isinstance(comp, building_blocks.Literal): - leaf_name_order.append(comp.value) - return comp, False - - transformation_utils.transform_preorder(complex_ast, transform) - preorder_nodes = [ - 'y', - 'z', - 1, - 2, - 'v', - 't', - 3, - 4, - 'u', - 5, - 6, - 7, - 'x', - 10, - 'w', - 8, - 9, - 11, - ] - - self.assertEqual(leaf_name_order, list(preorder_nodes)) - - def test_transform_preorder_passes_transform_through_tuple_correctly(self): - - def transform_intrinsic_to_reference(comp): - if isinstance(comp, building_blocks.Literal): - return ( - building_blocks.Reference(str(comp.value), comp.type_signature), - True, - ) - return comp, False - - tuple_holding_data = building_blocks.Struct( - [building_blocks.Literal(1, computation_types.TensorType(np.int32))] - ) - literal_replaced, modified = transformation_utils.transform_preorder( - tuple_holding_data, transform_intrinsic_to_reference - ) - self.assertTrue(modified) - self.assertEqual( - literal_replaced.compact_representation(), - tuple_holding_data.compact_representation(), - ) - self.assertLen(literal_replaced, 1) - self.assertIsInstance(literal_replaced[0], building_blocks.Reference) - - -class GetUniqueNamesTest(absltest.TestCase): - - def test_raises_on_none(self): - with self.assertRaises(TypeError): - transformation_utils.get_unique_names(None) - - def test_returns_names_single_lambda(self): - ref = building_blocks.Reference('x', np.int32) - lambda_1 = building_blocks.Lambda('x', np.int32, ref) - names = transformation_utils.get_unique_names(lambda_1) - self.assertCountEqual(names, ('x',)) - - def test_returns_names_nested_lambdas_with_different_variable_name(self): - ref = building_blocks.Reference('x', np.int32) - lambda_1 = building_blocks.Lambda('x', np.int32, ref) - lambda_2 = building_blocks.Lambda('y', np.int32, lambda_1) - names = transformation_utils.get_unique_names(lambda_2) - self.assertCountEqual(names, ('x', 'y')) - - def test_returns_names_single_block(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block = building_blocks.Block([('x', lit)], lit) - names = transformation_utils.get_unique_names(block) - self.assertCountEqual(names, ('x',)) - - def test_returns_names_nested_blocks_with_different_variable_name(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block_1 = building_blocks.Block([('x', lit)], lit) - block_2 = building_blocks.Block([('y', lit)], block_1) - names = transformation_utils.get_unique_names(block_2) - self.assertCountEqual(names, ('x', 'y')) - - def test_captures_reference_name(self): - ref_to_x = building_blocks.Reference('x', np.int32) - names = transformation_utils.get_unique_names(ref_to_x) - self.assertCountEqual(names, 'x') - - def test_captures_unbound_reference_name(self): - ref_to_z = building_blocks.Reference('z', np.int32) - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block_1 = building_blocks.Block([('x', lit)], ref_to_z) - block_2 = building_blocks.Block([('y', lit)], block_1) - names = transformation_utils.get_unique_names(block_2) - self.assertCountEqual(names, ('x', 'y', 'z')) - - -class GetMapOfUnboundReferencesTest(absltest.TestCase): - - def test_lambda_under_call_to_ref_gets_nothing_unbound(self): - y_ref = building_blocks.Reference('y', np.int32) - lambda_1 = building_blocks.Lambda('y', y_ref.type_signature, y_ref) - x_ref = building_blocks.Reference('x', np.int32) - call_on_x_ref = building_blocks.Call(lambda_1, x_ref) - unbound_refs = transformation_utils.get_map_of_unbound_references( - call_on_x_ref - )[lambda_1] - self.assertEmpty(unbound_refs) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/transformations.py b/tensorflow_federated/python/core/impl/compiler/transformations.py index 493471f201..287a5a60ba 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformations.py +++ b/tensorflow_federated/python/core/impl/compiler/transformations.py @@ -21,21 +21,16 @@ from collections.abc import Collection, Sequence import attrs +import federated_language from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.types import computation_types def to_call_dominant( - comp: building_blocks.ComputationBuildingBlock, -) -> building_blocks.ComputationBuildingBlock: + comp: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.ComputationBuildingBlock: """Transforms local (non-federated) computations into call-dominant form. Args: @@ -65,7 +60,7 @@ def to_call_dominant( # Top-level comp must be a lambda to ensure that we create a set of bindings # immediately under it, as `_build` does for all lambdas. global_comp = comp - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) class _Scope: """Name resolution scopes which track the creation of new value bindings.""" @@ -104,7 +99,9 @@ def create_binding(self, value): else: name = next(name_generator) self._newly_bound_values.append((name, value)) - reference = building_blocks.Reference(name, value.type_signature) + reference = federated_language.framework.Reference( + name, value.type_signature + ) self._locals[name] = reference return reference @@ -122,53 +119,61 @@ def bindings_to_block_with_result(self, result): if not self._newly_bound_values: return result else: - return building_blocks.Block(self._newly_bound_values, result) + return federated_language.framework.Block( + self._newly_bound_values, result + ) def _build(comp, scope): """Transforms `comp` to CDF, possibly adding bindings to `scope`.""" # The structure returned by this function is a generalized version of # call-dominant form. This function may result in the patterns specified in # the top-level function's docstring. - if isinstance(comp, building_blocks.Reference): + if isinstance(comp, federated_language.framework.Reference): result = scope.resolve(comp.name) if result is None: # If `comp.name` is only bound outside of `comp`, we can't resolve it. return comp return result - elif isinstance(comp, building_blocks.Selection): + elif isinstance(comp, federated_language.framework.Selection): source = _build(comp.source, scope) - if isinstance(source, building_blocks.Struct): + if isinstance(source, federated_language.framework.Struct): return source[comp.as_index()] - return building_blocks.Selection(source, index=comp.as_index()) - elif isinstance(comp, building_blocks.Struct): + return federated_language.framework.Selection( + source, index=comp.as_index() + ) + elif isinstance(comp, federated_language.framework.Struct): elements = [] for name, value in structure.iter_elements(comp): value = _build(value, scope) elements.append((name, value)) - return building_blocks.Struct(elements) - elif isinstance(comp, building_blocks.Call): + return federated_language.framework.Struct(elements) + elif isinstance(comp, federated_language.framework.Call): function = _build(comp.function, scope) argument = None if comp.argument is None else _build(comp.argument, scope) - if isinstance(function, building_blocks.Lambda): + if isinstance(function, federated_language.framework.Lambda): if argument is not None: scope = scope.new_child() scope.add_local(function.parameter_name, argument) return _build(function.result, scope) else: - return scope.create_binding(building_blocks.Call(function, argument)) - elif isinstance(comp, building_blocks.Lambda): + return scope.create_binding( + federated_language.framework.Call(function, argument) + ) + elif isinstance(comp, federated_language.framework.Lambda): scope = scope.new_child_with_bindings() if comp.parameter_name: scope.add_local( comp.parameter_name, - building_blocks.Reference(comp.parameter_name, comp.parameter_type), + federated_language.framework.Reference( + comp.parameter_name, comp.parameter_type + ), ) result = _build(comp.result, scope) block = scope.bindings_to_block_with_result(result) - return building_blocks.Lambda( + return federated_language.framework.Lambda( comp.parameter_name, comp.parameter_type, block ) - elif isinstance(comp, building_blocks.Block): + elif isinstance(comp, federated_language.framework.Block): scope = scope.new_child() for name, value in comp.locals: scope.add_local(name, _build(value, scope)) @@ -176,11 +181,11 @@ def _build(comp, scope): elif isinstance( comp, ( - building_blocks.CompiledComputation, - building_blocks.Data, - building_blocks.Intrinsic, - building_blocks.Literal, - building_blocks.Placement, + federated_language.framework.CompiledComputation, + federated_language.framework.Data, + federated_language.framework.Intrinsic, + federated_language.framework.Literal, + federated_language.framework.Placement, ), ): return comp @@ -201,9 +206,9 @@ def _build(comp, scope): def get_normalized_call_dominant_lambda( - comp: building_blocks.Lambda, + comp: federated_language.framework.Lambda, normalize_all_equal_bit: bool = True, -) -> building_blocks.Lambda: +) -> federated_language.framework.Lambda: """Creates normalized call dominant form for a lambda computation. Args: @@ -215,7 +220,7 @@ def get_normalized_call_dominant_lambda( lambda computation in CDF (call-dominant form) and the result component of the lambda is guaranteed to be a block. """ - py_typecheck.check_type(comp, building_blocks.Lambda) + py_typecheck.check_type(comp, federated_language.framework.Lambda) # Simplify the `comp` before transforming it to call-dominant form. comp, _ = tree_transformations.remove_mapped_or_applied_identity(comp) @@ -227,14 +232,16 @@ def get_normalized_call_dominant_lambda( # CDF can potentially return blocks if there are variables not dependent on # the top-level parameter. We normalize these away. - if not isinstance(comp, building_blocks.Lambda): - if not isinstance(comp, building_blocks.Block): - raise building_blocks.UnexpectedBlockError(building_blocks.Block, comp) - if not isinstance(comp.result, building_blocks.Lambda): - raise building_blocks.UnexpectedBlockError( - building_blocks.Lambda, comp.result + if not isinstance(comp, federated_language.framework.Lambda): + if not isinstance(comp, federated_language.framework.Block): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Block, comp + ) + if not isinstance(comp.result, federated_language.framework.Lambda): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Lambda, comp.result ) - if isinstance(comp.result.result, building_blocks.Block): + if isinstance(comp.result.result, federated_language.framework.Block): additional_locals = comp.result.result.locals result = comp.result.result.result else: @@ -244,28 +251,30 @@ def get_normalized_call_dominant_lambda( # shadow `comp.result.parameter_name`. However, `to_call_dominant` # above ensure that names are unique, as it ends in a call to # `uniquify_reference_names`. - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( comp.result.parameter_name, comp.result.parameter_type, - building_blocks.Block(comp.locals + additional_locals, result), + federated_language.framework.Block( + comp.locals + additional_locals, result + ), ) # Simple computations with no intrinsic calls won't have a block. # Normalize these as well. - if not isinstance(comp.result, building_blocks.Block): - comp = building_blocks.Lambda( + if not isinstance(comp.result, federated_language.framework.Block): + comp = federated_language.framework.Lambda( comp.parameter_name, comp.parameter_type, - building_blocks.Block([], comp.result), + federated_language.framework.Block([], comp.result), ) comp = tree_transformations.normalize_types(comp, normalize_all_equal_bit) - tree_analysis.check_contains_no_unbound_references(comp) + federated_language.framework.check_contains_no_unbound_references(comp) return comp -_NamedBinding = tuple[str, building_blocks.Call] +_NamedBinding = tuple[str, federated_language.framework.Call] @attrs.define @@ -300,7 +309,7 @@ def _compute_intrinsic_dependencies( intrinsic_dependencies = set() def record_dependencies(subvalue): - if isinstance(subvalue, building_blocks.Reference): + if isinstance(subvalue, federated_language.framework.Reference): if subvalue.name not in intrinsic_dependencies_for_ref: names = [(n, v.compact_representation()) for n, v in locals_list] raise ValueError( @@ -311,14 +320,14 @@ def record_dependencies(subvalue): intrinsic_dependencies.update( # pylint: disable=cell-var-from-loop intrinsic_dependencies_for_ref[subvalue.name] ) - elif isinstance(subvalue, building_blocks.Lambda): + elif isinstance(subvalue, federated_language.framework.Lambda): # We treat the lambdas that appear in CDF (inside intrinsic invocations) # as though their parameters are independent of the rest of the # computation. Note that we're not careful about saving and then # restoring old variables here: this is okay because call-dominant form # guarantees unique variable names. intrinsic_dependencies_for_ref[subvalue.parameter_name] = set() - elif isinstance(subvalue, building_blocks.Block): + elif isinstance(subvalue, federated_language.framework.Block): # Since we're in CDF, the only blocks inside the bodies of arguments # are within lambda arguments to intrinsics. We don't need to record # dependencies of these since they can't rely on the results of other @@ -326,12 +335,16 @@ def record_dependencies(subvalue): for subvalue_local_name, _ in subvalue.locals: intrinsic_dependencies_for_ref[subvalue_local_name] = set() - tree_analysis.visit_preorder(local_value, record_dependencies) + federated_language.framework.visit_preorder( + local_value, record_dependencies + ) # All intrinsic calls are guaranteed to be top-level in call-dominant form. if ( - isinstance(local_value, building_blocks.Call) - and isinstance(local_value.function, building_blocks.Intrinsic) + isinstance(local_value, federated_language.framework.Call) + and isinstance( + local_value.function, federated_language.framework.Intrinsic + ) and local_value.function.uri in intrinsic_uris ): if intrinsic_dependencies: @@ -360,13 +373,13 @@ def record_dependencies(subvalue): @attrs.define class _MergedIntrinsic: uri: str - args: building_blocks.ComputationBuildingBlock - return_type: computation_types.Type + args: federated_language.framework.ComputationBuildingBlock + return_type: federated_language.Type unpack_to_locals: list[str] def _compute_merged_intrinsics( - intrinsic_defaults: list[building_blocks.Call], + intrinsic_defaults: list[federated_language.framework.Call], uri_to_locals: dict[str, list[_NamedBinding]], name_generator, ) -> list[_MergedIntrinsic]: @@ -388,10 +401,12 @@ def _compute_merged_intrinsics( """ results = [] for default_call in intrinsic_defaults: - if not isinstance(default_call.function, building_blocks.Intrinsic): + if not isinstance( + default_call.function, federated_language.framework.Intrinsic + ): raise ValueError( "Expected 'default_call.function' to be a " - '`building_blocks.Intrinsic`, found ' + '`federated_language.framework.Intrinsic`, found ' f'`{type(default_call.function)}`.' ) uri = default_call.function.uri @@ -418,8 +433,8 @@ def _compute_merged_intrinsics( 'encountered call with all_equal value ' f'{call.type_signature.all_equal}' # pytype: disable=attribute-error ) - return_type = computation_types.FederatedType( - computation_types.StructType( + return_type = federated_language.FederatedType( + federated_language.StructType( [(None, call.type_signature.member) for call in calls] # pytype: disable=attribute-error ), placement=result_placement, @@ -445,9 +460,9 @@ def _compute_merged_intrinsics( def _merge_args( abstract_parameter_type, - args: list[building_blocks.ComputationBuildingBlock], + args: list[federated_language.framework.ComputationBuildingBlock], name_generator, -) -> building_blocks.ComputationBuildingBlock: +) -> federated_language.framework.ComputationBuildingBlock: """Merges the arguments of multiple function invocations into one. Args: @@ -460,9 +475,9 @@ def _merge_args( Returns: A building block to use as the new (merged) argument. """ - if isinstance(abstract_parameter_type, computation_types.FederatedType): - zip_args = building_block_factory.create_federated_zip( - building_blocks.Struct(args) + if isinstance(abstract_parameter_type, federated_language.FederatedType): + zip_args = federated_language.framework.create_federated_zip( + federated_language.framework.Struct(args) ) # `create_federated_zip` introduces repeated names. zip_args, _ = tree_transformations.uniquify_reference_names( @@ -472,12 +487,12 @@ def _merge_args( if isinstance( abstract_parameter_type, ( - computation_types.AbstractType, - computation_types.TensorType, + federated_language.AbstractType, + federated_language.TensorType, ), ): - return building_blocks.Struct([(None, arg) for arg in args]) - if isinstance(abstract_parameter_type, computation_types.FunctionType): + return federated_language.framework.Struct([(None, arg) for arg in args]) + if isinstance(abstract_parameter_type, federated_language.FunctionType): # For functions, we must compose them differently depending on whether the # abstract function (from the intrinsic definition) takes more than one # parameter. @@ -498,70 +513,86 @@ def _merge_args( # )` param_name = next(name_generator) if isinstance( - abstract_parameter_type.parameter, computation_types.StructType + abstract_parameter_type.parameter, federated_language.StructType ): num_args = len(abstract_parameter_type.parameter) parameter_types = [[] for _ in range(num_args)] for arg in args: for i in range(num_args): parameter_types[i].append(arg.type_signature.parameter[i]) # pytype: disable=attribute-error - param_type = computation_types.StructType(parameter_types) - param_ref = building_blocks.Reference(param_name, param_type) + param_type = federated_language.StructType(parameter_types) + param_ref = federated_language.framework.Reference(param_name, param_type) calls = [] for n, fn in enumerate(args): args_to_fn = [] for i in range(num_args): args_to_fn.append( - building_blocks.Selection( - building_blocks.Selection(param_ref, index=i), index=n + federated_language.framework.Selection( + federated_language.framework.Selection(param_ref, index=i), + index=n, ) ) calls.append( - building_blocks.Call( - fn, building_blocks.Struct([(None, arg) for arg in args_to_fn]) + federated_language.framework.Call( + fn, + federated_language.framework.Struct( + [(None, arg) for arg in args_to_fn] + ), ) ) else: - param_type = computation_types.StructType( + param_type = federated_language.StructType( [arg.type_signature.parameter for arg in args] # pytype: disable=attribute-error ) - param_ref = building_blocks.Reference(param_name, param_type) + param_ref = federated_language.framework.Reference(param_name, param_type) calls = [ - building_blocks.Call( - fn, building_blocks.Selection(param_ref, index=n) + federated_language.framework.Call( + fn, federated_language.framework.Selection(param_ref, index=n) ) for (n, fn) in enumerate(args) ] - return building_blocks.Lambda( + return federated_language.framework.Lambda( parameter_name=param_name, parameter_type=param_type, - result=building_blocks.Struct([(None, call) for call in calls]), + result=federated_language.framework.Struct( + [(None, call) for call in calls] + ), ) - if isinstance(abstract_parameter_type, computation_types.StructType): + if isinstance(abstract_parameter_type, federated_language.StructType): # Bind each argument to a name so that we can reference them multiple times. arg_locals = [] arg_refs = [] for arg in args: arg_name = next(name_generator) arg_locals.append((arg_name, arg)) - arg_refs.append(building_blocks.Reference(arg_name, arg.type_signature)) + arg_refs.append( + federated_language.framework.Reference(arg_name, arg.type_signature) + ) merged_args = [] for i, _ in enumerate(abstract_parameter_type): - ith_args = [building_blocks.Selection(ref, index=i) for ref in arg_refs] + ith_args = [ + federated_language.framework.Selection(ref, index=i) + for ref in arg_refs + ] merged_args.append( _merge_args(abstract_parameter_type[i], ith_args, name_generator) ) - return building_blocks.Block( - arg_locals, building_blocks.Struct([(None, arg) for arg in merged_args]) + return federated_language.framework.Block( + arg_locals, + federated_language.framework.Struct( + [(None, arg) for arg in merged_args] + ), ) raise TypeError(f'Cannot merge args of type: {abstract_parameter_type}') # TODO: b/266565233 - Remove during MapReduceForm and BroadcastForm cleanup. def force_align_and_split_by_intrinsics( - comp: building_blocks.Lambda, - intrinsic_defaults: list[building_blocks.Call], -) -> tuple[building_blocks.Lambda, building_blocks.Lambda]: + comp: federated_language.framework.Lambda, + intrinsic_defaults: list[federated_language.framework.Call], +) -> tuple[ + federated_language.framework.Lambda, federated_language.framework.Lambda +]: """Divides `comp` into before-and-after of calls to one or more intrinsics. The input computation `comp` must have the following properties: @@ -589,7 +620,8 @@ def force_align_and_split_by_intrinsics( `f(merged_arg).member = (f1(f1_arg).member, f2(f2_arg).member)` Under these conditions, (and assuming `comp` is a computation with non-`None` - argument), this function will return two `building_blocks.Lambda`s `before` + argument), this function will return two + `federated_language.framework.Lambda`s `before` and `after` such that `comp` is semantically equivalent to the following expression*: @@ -634,8 +666,8 @@ def force_align_and_split_by_intrinsics( original argument to `comp`, as it may be dependent on both. Args: - comp: The instance of `building_blocks.Lambda` that serves as the input to - this transformation, as described above. + comp: The instance of `federated_language.framework.Lambda` that serves as + the input to this transformation, as described above. intrinsic_defaults: A list of intrinsics with which to split the computation, provided as a list of `Call`s to insert if no intrinsic with a matching URI is found. Intrinsics in this list will be merged, and @@ -643,10 +675,11 @@ def force_align_and_split_by_intrinsics( Returns: A pair of the form `(before, after)`, where each of `before` and `after` - is a `building_blocks.ComputationBuildingBlock` instance that represents a + is a `federated_language.framework.ComputationBuildingBlock` instance that + represents a part of the result as specified above. """ - py_typecheck.check_type(comp, building_blocks.Lambda) + py_typecheck.check_type(comp, federated_language.framework.Lambda) py_typecheck.check_type(intrinsic_defaults, list) comp_repr = comp.compact_representation() @@ -657,14 +690,16 @@ def force_align_and_split_by_intrinsics( # CDF can potentially return blocks if there are variables not dependent on # the top-level parameter. We normalize these away. - if not isinstance(comp, building_blocks.Lambda): - if not isinstance(comp, building_blocks.Block): - raise building_blocks.UnexpectedBlockError(building_blocks.Block, comp) - if not isinstance(comp.result, building_blocks.Lambda): - raise building_blocks.UnexpectedBlockError( - building_blocks.Lambda, comp.result + if not isinstance(comp, federated_language.framework.Lambda): + if not isinstance(comp, federated_language.framework.Block): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Block, comp ) - if isinstance(comp.result.result, building_blocks.Block): + if not isinstance(comp.result, federated_language.framework.Lambda): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Lambda, comp.result + ) + if isinstance(comp.result.result, federated_language.framework.Block): additional_locals = comp.result.result.locals result = comp.result.result.result else: @@ -674,22 +709,24 @@ def force_align_and_split_by_intrinsics( # shadow `comp.result.parameter_name`. However, `to_call_dominant` # above ensure that names are unique, as it ends in a call to # `uniquify_reference_names`. - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( comp.result.parameter_name, comp.result.parameter_type, - building_blocks.Block(comp.locals + additional_locals, result), + federated_language.framework.Block( + comp.locals + additional_locals, result + ), ) # Simple computations with no intrinsic calls won't have a block. # Normalize these as well. - if not isinstance(comp.result, building_blocks.Block): - comp = building_blocks.Lambda( + if not isinstance(comp.result, federated_language.framework.Block): + comp = federated_language.framework.Lambda( comp.parameter_name, comp.parameter_type, - building_blocks.Block([], comp.result), + federated_language.framework.Block([], comp.result), ) - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) intrinsic_uris = set(call.function.uri for call in intrinsic_defaults) deps = _compute_intrinsic_dependencies( @@ -701,17 +738,15 @@ def force_align_and_split_by_intrinsics( # Note: the outputs are labeled as `{uri}_param for convenience, e.g. # `federated_secure_sum_param: ...`. - before = building_blocks.Lambda( + before = federated_language.framework.Lambda( comp.parameter_name, comp.parameter_type, - building_blocks.Block( + federated_language.framework.Block( deps.locals_not_dependent_on_intrinsics, - building_blocks.Struct( - [ - (f'{merged.uri}_param', merged.args) - for merged in merged_intrinsics - ] - ), + federated_language.framework.Struct([ + (f'{merged.uri}_param', merged.args) + for merged in merged_intrinsics + ]), ), ) @@ -719,39 +754,35 @@ def force_align_and_split_by_intrinsics( if comp.parameter_type is not None: # TODO: b/147499373 - If None-arguments were uniformly represented as empty # tuples, we would be able to avoid this (and related) ugly casing. - after_param_type = computation_types.StructType([ + after_param_type = federated_language.StructType([ ('original_arg', comp.parameter_type), ( 'intrinsic_results', - computation_types.StructType( - [ - (f'{merged.uri}_result', merged.return_type) - for merged in merged_intrinsics - ] - ), + federated_language.StructType([ + (f'{merged.uri}_result', merged.return_type) + for merged in merged_intrinsics + ]), ), ]) else: - after_param_type = computation_types.StructType( - [ - ( - 'intrinsic_results', - computation_types.StructType( - [ - (f'{merged.uri}_result', merged.return_type) - for merged in merged_intrinsics - ] - ), - ), - ] - ) - after_param_ref = building_blocks.Reference( + after_param_type = federated_language.StructType([ + ( + 'intrinsic_results', + federated_language.StructType([ + (f'{merged.uri}_result', merged.return_type) + for merged in merged_intrinsics + ]), + ), + ]) + after_param_ref = federated_language.framework.Reference( after_param_name, after_param_type ) if comp.parameter_type is not None: original_arg_bindings = [( comp.parameter_name, - building_blocks.Selection(after_param_ref, name='original_arg'), + federated_language.framework.Selection( + after_param_ref, name='original_arg' + ), )] else: original_arg_bindings = [] @@ -759,29 +790,33 @@ def force_align_and_split_by_intrinsics( unzip_bindings = [] for merged in merged_intrinsics: if merged.unpack_to_locals: - intrinsic_result = building_blocks.Selection( - building_blocks.Selection(after_param_ref, name='intrinsic_results'), + intrinsic_result = federated_language.framework.Selection( + federated_language.framework.Selection( + after_param_ref, name='intrinsic_results' + ), name=f'{merged.uri}_result', ) select_param_type = intrinsic_result.type_signature.member for i, binding_name in enumerate(merged.unpack_to_locals): select_param_name = next(name_generator) - select_param_ref = building_blocks.Reference( + select_param_ref = federated_language.framework.Reference( select_param_name, select_param_type ) - selected = building_block_factory.create_federated_map_or_apply( - building_blocks.Lambda( + selected = federated_language.framework.create_federated_map_or_apply( + federated_language.framework.Lambda( select_param_name, select_param_type, - building_blocks.Selection(select_param_ref, index=i), + federated_language.framework.Selection( + select_param_ref, index=i + ), ), intrinsic_result, ) unzip_bindings.append((binding_name, selected)) - after = building_blocks.Lambda( + after = federated_language.framework.Lambda( after_param_name, after_param_type, - building_blocks.Block( + federated_language.framework.Block( original_arg_bindings + # Note that we must duplicate `locals_not_dependent_on_intrinsics` # across both the `before` and `after` computations since both can @@ -796,17 +831,19 @@ def force_align_and_split_by_intrinsics( ), ) try: - tree_analysis.check_has_unique_names(before) - tree_analysis.check_has_unique_names(after) - except tree_analysis.NonuniqueNameError as e: + federated_language.framework.check_has_unique_names(before) + federated_language.framework.check_has_unique_names(after) + except federated_language.framework.NonuniqueNameError as e: raise ValueError(f'nonunique names in result of splitting\n{comp}') from e return before, after def _augment_lambda_with_parameter_for_unbound_references( - comp: building_blocks.Lambda, lambda_parameter_extension_name: str + comp: federated_language.framework.Lambda, + lambda_parameter_extension_name: str, ) -> tuple[ - building_blocks.Lambda, list[building_blocks.ComputationBuildingBlock] + federated_language.framework.Lambda, + list[federated_language.framework.ComputationBuildingBlock], ]: """Resolves unbound references in `comp` by extending the input parameter. @@ -846,9 +883,9 @@ def _augment_lambda_with_parameter_for_unbound_references( that are unsupported. """ - py_typecheck.check_type(comp, building_blocks.Lambda) + py_typecheck.check_type(comp, federated_language.framework.Lambda) py_typecheck.check_type( - comp.type_signature.parameter, computation_types.StructType + comp.type_signature.parameter, federated_language.StructType ) comp_parameter_name = comp.parameter_name @@ -861,8 +898,10 @@ def _check_input_parameter_used_via_selection(inner_comp): # "transformed" so that we skip traversal of the inner reference subtree # below. if ( - isinstance(inner_comp, building_blocks.Selection) - and isinstance(inner_comp.source, building_blocks.Reference) + isinstance(inner_comp, federated_language.framework.Selection) + and isinstance( + inner_comp.source, federated_language.framework.Reference + ) and inner_comp.source.name == comp_parameter_name ): return inner_comp, True @@ -873,7 +912,7 @@ def _check_input_parameter_used_via_selection(inner_comp): # later step when we attempt to replace the input parameter with an # augmented one. if ( - isinstance(inner_comp, building_blocks.Reference) + isinstance(inner_comp, federated_language.framework.Reference) and inner_comp.name == comp_parameter_name ): raise ValueError( @@ -884,11 +923,13 @@ def _check_input_parameter_used_via_selection(inner_comp): # Trace the computation to ensure that the input parameter is always used via # a selection and never used directly. - transformation_utils.transform_preorder( + federated_language.framework.transform_preorder( comp, _check_input_parameter_used_via_selection ) - unbound_refs = transformation_utils.get_map_of_unbound_references(comp) + unbound_refs = federated_language.framework.get_map_of_unbound_references( + comp + ) top_level_unbound_refs = unbound_refs[comp] # Maintain a map where the keys are the computations that should be passed to @@ -899,16 +940,16 @@ def _check_input_parameter_used_via_selection(inner_comp): def _is_replacement_candidate(inner_comp): # A replacement is needed if the subtree represents a reference to an top- # level unbound ref. - if isinstance(inner_comp, building_blocks.Reference) and unbound_refs[ - inner_comp - ].issubset(top_level_unbound_refs): + if isinstance( + inner_comp, federated_language.framework.Reference + ) and unbound_refs[inner_comp].issubset(top_level_unbound_refs): return True # A replacement is also needed if the subtree represents a selection into # a top-level unbound ref. We trigger the replacement on selections at this # level so that we can pass the minimal amount of information possible # through the extended input parameter. - if isinstance(inner_comp, building_blocks.Selection): + if isinstance(inner_comp, federated_language.framework.Selection): return _is_replacement_candidate(inner_comp.source) return False @@ -930,10 +971,12 @@ def _compute_new_parameter_elements(inner_comp): # list of new input comps. Use a preorder transformation since it is # important to replace larger subtrees when possible (e.g. replacing an entire # selection subtree vs just replacing the selection source subtree). - transformation_utils.transform_preorder(comp, _compute_new_parameter_elements) + federated_language.framework.transform_preorder( + comp, _compute_new_parameter_elements + ) # Update the comp parameter type to include the new extension. - new_parameter_type = computation_types.StructType( + new_parameter_type = federated_language.StructType( list(comp.type_signature.parameter.items()) # pytype: disable=attribute-error + [( lambda_parameter_extension_name, @@ -947,9 +990,9 @@ def _rebind_unbound_references_to_new_parameter(inner_comp): # selection into the list of new input comps. if _is_replacement_candidate(inner_comp): assert inner_comp in new_input_comps - new_comp = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference( + new_comp = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( comp_parameter_name, new_parameter_type ), # The input param extension will be added at the end. @@ -962,13 +1005,13 @@ def _rebind_unbound_references_to_new_parameter(inner_comp): # Replace selections into the original input parameter with selections into # the extended input parameter to maintain type signature correctness. if ( - isinstance(inner_comp, building_blocks.Selection) - and isinstance(inner_comp, building_blocks.Reference) + isinstance(inner_comp, federated_language.framework.Selection) + and isinstance(inner_comp, federated_language.framework.Reference) and inner_comp.source.name == comp_parameter_name ): return ( - building_blocks.Selection( - building_blocks.Reference( + federated_language.framework.Selection( + federated_language.framework.Reference( comp_parameter_name, new_parameter_type ), # Use the same index as before. @@ -983,10 +1026,10 @@ def _rebind_unbound_references_to_new_parameter(inner_comp): # comps and also update existing selections into the original input parameter. # Use a preorder transformation again to ensure that the new input comps are # used in the correct order. - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( comp.parameter_name, new_parameter_type, - transformation_utils.transform_preorder( + federated_language.framework.transform_preorder( comp.result, _rebind_unbound_references_to_new_parameter )[0], ) @@ -1001,26 +1044,33 @@ class UnavailableRequiredInputsError(ValueError): # Helper function to replace references with a given name with a different # computation. def _replace_references( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, ref_name: str, - replacement: building_blocks.ComputationBuildingBlock, -) -> building_blocks.ComputationBuildingBlock: + replacement: federated_language.framework.ComputationBuildingBlock, +) -> federated_language.framework.ComputationBuildingBlock: def _replace(comp): - if isinstance(comp, building_blocks.Reference) and comp.name == ref_name: + if ( + isinstance(comp, federated_language.framework.Reference) + and comp.name == ref_name + ): return replacement, True return comp, False - return transformation_utils.transform_postorder(comp, _replace)[0] + return federated_language.framework.transform_postorder(comp, _replace)[0] def divisive_force_align_and_split_by_intrinsics( - comp: building_blocks.Lambda, - intrinsic_defs_to_split: Collection[intrinsic_defs.IntrinsicDef], + comp: federated_language.framework.Lambda, + intrinsic_defs_to_split: Collection[ + federated_language.framework.IntrinsicDef + ], before_comp_allowed_original_arg_subparameters: Sequence[Sequence[int]], intrinsic_comp_allowed_original_arg_subparameters: Sequence[Sequence[int]], after_comp_allowed_original_arg_subparameters: Sequence[Sequence[int]], ) -> tuple[ - building_blocks.Lambda, building_blocks.Lambda, building_blocks.Lambda + federated_language.framework.Lambda, + federated_language.framework.Lambda, + federated_language.framework.Lambda, ]: """Divides `comp` into three components (before, intrinsic, after). @@ -1042,7 +1092,8 @@ def divisive_force_align_and_split_by_intrinsics( `intrinsic_defs_to_split`. Under these conditions, this function will return three - `building_blocks.Lambda`s `before`, `intrinsic`, and `after` such that + `federated_language.framework.Lambda`s `before`, `intrinsic`, and `after` such + that `comp` is semantically equivalent to the following expression: ``` (arg -> (let @@ -1098,8 +1149,8 @@ def divisive_force_align_and_split_by_intrinsics( function is guaranteed to find it. Args: - comp: The instance of `building_blocks.Lambda` that serves as the input to - this transformation, as described above. + comp: The instance of `federated_language.framework.Lambda` that serves as + the input to this transformation, as described above. intrinsic_defs_to_split: A list of intrinsics with which to split the computation. before_comp_allowed_original_arg_subparameters: A list of paths describing @@ -1114,8 +1165,9 @@ def divisive_force_align_and_split_by_intrinsics( Returns: A tuple of the form `(before, intrinsic, after)`, where each of `before`, - `intrinsic`, and `after` is a building_blocks.Lambda` instance with a - `building_blocks.Block` result. + `intrinsic`, and `after` is a federated_language.framework.Lambda` instance + with a + `federated_language.framework.Block` result. Details about the inputs and outputs of the three computations as well as the contents of the `intrinsic` comp are specified above. @@ -1151,7 +1203,7 @@ def divisive_force_align_and_split_by_intrinsics( # promised guarantees. ############################### Step 1 ###################################### - if not isinstance(comp, building_blocks.Lambda): + if not isinstance(comp, federated_language.framework.Lambda): raise TypeError('Expected input computation to be a lambda computation.') if not comp.parameter_name or not comp.parameter_type: @@ -1171,9 +1223,9 @@ def divisive_force_align_and_split_by_intrinsics( intrinsic_uris = set( intrinsic_def.uri for intrinsic_def in intrinsic_defs_to_split ) - if not isinstance(comp.result, building_blocks.Block): - raise building_blocks.UnexpectedBlockError( - building_blocks.Block, comp.result + if not isinstance(comp.result, federated_language.framework.Block): + raise federated_language.framework.UnexpectedBlockError( + federated_language.framework.Block, comp.result ) deps = _compute_intrinsic_dependencies( intrinsic_uris, @@ -1184,18 +1236,21 @@ def divisive_force_align_and_split_by_intrinsics( ############################### Step 2 ###################################### # Generate a preliminary intrinsic comp. - intrinsic_locals: list[tuple[str, building_blocks.Call]] = [] + intrinsic_locals: list[tuple[str, federated_language.framework.Call]] = [] for intrinsic_locals_for_uri in deps.uri_to_locals.values(): intrinsic_locals.extend(intrinsic_locals_for_uri) intrinsic_results = [ - building_blocks.Reference(local_name, local_value.type_signature) + federated_language.framework.Reference( + local_name, local_value.type_signature + ) for local_name, local_value in intrinsic_locals ] - preliminary_intrinsic_comp = building_blocks.Lambda( + preliminary_intrinsic_comp = federated_language.framework.Lambda( comp.parameter_name, comp.parameter_type, - building_blocks.Block( - intrinsic_locals, building_blocks.Struct(intrinsic_results) + federated_language.framework.Block( + intrinsic_locals, + federated_language.framework.Struct(intrinsic_results), ), ) @@ -1225,11 +1280,13 @@ def divisive_force_align_and_split_by_intrinsics( len(intrinsic_comp.parameter_type) == len(preliminary_intrinsic_comp.parameter_type) + 1 ) - tree_analysis.check_contains_no_unbound_references(intrinsic_comp) + federated_language.framework.check_contains_no_unbound_references( + intrinsic_comp + ) ############################### Step 4 ###################################### # Generate a preliminary after comp. - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) after_param_name = next(name_generator) after_param_type = [ ('original_arg', comp.parameter_type), @@ -1240,31 +1297,37 @@ def divisive_force_align_and_split_by_intrinsics( ] original_arg_index = 0 intrinsic_results_index = 1 - intrinsic_result_bindings: list[tuple[str, building_blocks.Selection]] = [] + intrinsic_result_bindings: list[ + tuple[str, federated_language.framework.Selection] + ] = [] for i, (local_name, _) in enumerate(intrinsic_locals): intrinsic_result_bindings.append(( local_name, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference(after_param_name, after_param_type), + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + after_param_name, after_param_type + ), index=intrinsic_results_index, ), index=i, ), )) - preliminary_after_comp = building_blocks.Lambda( + preliminary_after_comp = federated_language.framework.Lambda( after_param_name, after_param_type, _replace_references( - building_blocks.Block( + federated_language.framework.Block( intrinsic_result_bindings + deps.locals_not_dependent_on_intrinsics + deps.locals_dependent_on_intrinsics, comp.result.result, ), comp.parameter_name, - building_blocks.Selection( - building_blocks.Reference(after_param_name, after_param_type), + federated_language.framework.Selection( + federated_language.framework.Reference( + after_param_name, after_param_type + ), index=original_arg_index, ), ), @@ -1299,7 +1362,7 @@ def divisive_force_align_and_split_by_intrinsics( preliminary_after_comp, after_param_name, { - (0,): building_blocks.Reference( + (0,): federated_language.framework.Reference( comp.parameter_name, comp.parameter_type ) }, @@ -1313,7 +1376,9 @@ def divisive_force_align_and_split_by_intrinsics( ) # This next version of the after comp should have no unbound references. - tree_analysis.check_contains_no_unbound_references(preliminary_after_comp) + federated_language.framework.check_contains_no_unbound_references( + preliminary_after_comp + ) ############################### Step 6 ###################################### # Generate a preliminary before comp that produces the values required by the @@ -1322,16 +1387,21 @@ def divisive_force_align_and_split_by_intrinsics( before_result = [ ( 'intrinsic_args_from_before_comp', - building_blocks.Struct(intrinsic_args_from_before_comp_values), + federated_language.framework.Struct( + intrinsic_args_from_before_comp_values + ), + ), + ( + 'intermediate_state', + federated_language.framework.Struct(intermediate_state), ), - ('intermediate_state', building_blocks.Struct(intermediate_state)), ] - preliminary_before_comp = building_blocks.Lambda( + preliminary_before_comp = federated_language.framework.Lambda( comp.parameter_name, comp.parameter_type, - building_blocks.Block( + federated_language.framework.Block( deps.locals_not_dependent_on_intrinsics, - building_blocks.Struct(before_result), + federated_language.framework.Struct(before_result), ), ) @@ -1347,7 +1417,9 @@ def divisive_force_align_and_split_by_intrinsics( # If the resulting comp is not valid (i.e. it contains no unbound # references), then a split is not possible and we throw an error. - if not tree_analysis.contains_no_unbound_references(preliminary_before_comp): + if not federated_language.framework.contains_no_unbound_references( + preliminary_before_comp + ): raise UnavailableRequiredInputsError( 'The computation is not splittable given the allowed subparameters.' ) @@ -1380,10 +1452,15 @@ def divisive_force_align_and_split_by_intrinsics( intermediate_state_index_in_before_result ] duplicate_intermediate_state_vals = [ - (None, building_blocks.Reference(local_name, local_value.type_signature)) + ( + None, + federated_language.framework.Reference( + local_name, local_value.type_signature + ), + ) for local_name, local_value in duplicated_locals ] - extended_intermediate_state_vals = building_blocks.Struct( + extended_intermediate_state_vals = federated_language.framework.Struct( structure.to_elements(intermediate_state_vals) + duplicate_intermediate_state_vals ) @@ -1395,15 +1472,15 @@ def divisive_force_align_and_split_by_intrinsics( # Update the before comp to produce the result with the extended intermediate # state. If we have reached this stage, this latest before comp should have # no unbound references. - before_comp = building_blocks.Lambda( + before_comp = federated_language.framework.Lambda( preliminary_before_comp.parameter_name, preliminary_before_comp.parameter_type, - building_blocks.Block( + federated_language.framework.Block( preliminary_before_comp.result.locals, - building_blocks.Struct(before_result_elements), + federated_language.framework.Struct(before_result_elements), ), ) - tree_analysis.check_contains_no_unbound_references(before_comp) + federated_language.framework.check_contains_no_unbound_references(before_comp) # Update the after comp parameter to represent the extended intermediate # state. Also restore the name associated with the intrinsic results portion @@ -1424,13 +1501,13 @@ def divisive_force_align_and_split_by_intrinsics( 'intermediate_state', [e.type_signature for e in extended_intermediate_state_vals], ) - preliminary_after_comp = building_blocks.Lambda( + preliminary_after_comp = federated_language.framework.Lambda( preliminary_after_comp.parameter_name, after_param_type_signature, _replace_references( preliminary_after_comp.result, preliminary_after_comp.parameter_name, - building_blocks.Reference( + federated_language.framework.Reference( preliminary_after_comp.parameter_name, after_param_type_signature ), ), @@ -1440,7 +1517,9 @@ def divisive_force_align_and_split_by_intrinsics( # extended intermediate state. Note when indexing into the extended # intermediate state that it consists of the portion constructed in step 5 # followed by the duplicated locals portion. - block_locals: list[tuple[str, building_blocks.ComputationBuildingBlock]] = [] + block_locals: list[ + tuple[str, federated_language.framework.ComputationBuildingBlock] + ] = [] duplicated_local_names = [local_name for local_name, _ in duplicated_locals] intermediate_state_index = ( len(preliminary_after_comp.type_signature.parameter) - 1 @@ -1450,9 +1529,9 @@ def divisive_force_align_and_split_by_intrinsics( if local_name in duplicated_local_names: block_locals.append(( local_name, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference( + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( preliminary_after_comp.parameter_name, preliminary_after_comp.parameter_type, ), @@ -1466,12 +1545,14 @@ def divisive_force_align_and_split_by_intrinsics( block_locals.append((local_name, local_value)) # Update the after comp to use the de-duplicated block locals. - after_comp = building_blocks.Lambda( + after_comp = federated_language.framework.Lambda( preliminary_after_comp.parameter_name, preliminary_after_comp.parameter_type, - building_blocks.Block(block_locals, preliminary_after_comp.result.result), + federated_language.framework.Block( + block_locals, preliminary_after_comp.result.result + ), ) - tree_analysis.check_contains_no_unbound_references(after_comp) + federated_language.framework.check_contains_no_unbound_references(after_comp) ############################### Step 8 ###################################### # Normalize all of the output computations. @@ -1490,8 +1571,10 @@ def divisive_force_align_and_split_by_intrinsics( # returned in the same order they are computed. expected_intrinsic_comp_result_names: list[str] = [] for intrinsic_local, intrinsic_call in intrinsic_comp.result.locals: - assert isinstance(intrinsic_call, building_blocks.Call) - assert isinstance(intrinsic_call.function, building_blocks.Intrinsic) + assert isinstance(intrinsic_call, federated_language.framework.Call) + assert isinstance( + intrinsic_call.function, federated_language.framework.Intrinsic + ) assert intrinsic_call.function.uri in intrinsic_uris expected_intrinsic_comp_result_names.append(intrinsic_local) actual_intrinsic_comp_result_names = [ diff --git a/tensorflow_federated/python/core/impl/compiler/transformations_test.py b/tensorflow_federated/python/core/impl/compiler/transformations_test.py index f3ae2cbfe8..cd8d7611f4 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformations_test.py +++ b/tensorflow_federated/python/core/impl/compiler/transformations_test.py @@ -15,30 +15,22 @@ from unittest import mock from absl.testing import absltest +import federated_language +from federated_language.proto import computation_pb2 import numpy as np -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.compiler import building_block_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils from tensorflow_federated.python.core.impl.compiler import transformations -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_test_utils class ToCallDominantTest(absltest.TestCase): def assert_compact_representations_equal( self, - actual: building_blocks.ComputationBuildingBlock, - expected: building_blocks.ComputationBuildingBlock, + actual: federated_language.framework.ComputationBuildingBlock, + expected: federated_language.framework.ComputationBuildingBlock, ): """Asserts that two building blocks have the same compact representation.""" self.assertEqual( @@ -46,12 +38,16 @@ def assert_compact_representations_equal( ) def test_inlines_references(self): - int_type = computation_types.TensorType(np.int32) - int_ref = lambda name: building_blocks.Reference(name, int_type) - int_fn = lambda name, result: building_blocks.Lambda(name, int_type, result) + int_type = federated_language.TensorType(np.int32) + int_ref = lambda name: federated_language.framework.Reference( + name, int_type + ) + int_fn = lambda name, result: federated_language.framework.Lambda( + name, int_type, result + ) before = int_fn( 'x', - building_blocks.Block( + federated_language.framework.Block( [ ('y', int_ref('x')), ('z', int_ref('y')), @@ -64,37 +60,39 @@ def test_inlines_references(self): self.assert_compact_representations_equal(after, expected) def test_inlines_selections(self): - int_type = computation_types.TensorType(np.int32) - structed = computation_types.StructType([int_type]) - double = computation_types.StructType([structed]) - before = building_blocks.Lambda( + int_type = federated_language.TensorType(np.int32) + structed = federated_language.StructType([int_type]) + double = federated_language.StructType([structed]) + before = federated_language.framework.Lambda( 'x', double, - building_blocks.Block( + federated_language.framework.Block( [ ( 'y', - building_blocks.Selection( - building_blocks.Reference('x', double), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference('x', double), + index=0, ), ), ( 'z', - building_blocks.Selection( - building_blocks.Reference('y', structed), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference('y', structed), + index=0, ), ), ], - building_blocks.Reference('z', int_type), + federated_language.framework.Reference('z', int_type), ), ) after = transformations.to_call_dominant(before) - expected = building_blocks.Lambda( + expected = federated_language.framework.Lambda( 'x', double, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', double), index=0 + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('x', double), index=0 ), index=0, ), @@ -102,122 +100,137 @@ def test_inlines_selections(self): self.assert_compact_representations_equal(after, expected) def test_inlines_structs(self): - int_type = computation_types.TensorType(np.int32) - structed = computation_types.StructType([int_type]) - double = computation_types.StructType([structed]) - before = building_blocks.Lambda( + int_type = federated_language.TensorType(np.int32) + structed = federated_language.StructType([int_type]) + double = federated_language.StructType([structed]) + before = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Block( + federated_language.framework.Block( [ ( 'y', - building_blocks.Struct( - [building_blocks.Reference('x', int_type)] + federated_language.framework.Struct( + [federated_language.framework.Reference('x', int_type)] ), ), ( 'z', - building_blocks.Struct( - [building_blocks.Reference('y', structed)] + federated_language.framework.Struct( + [federated_language.framework.Reference('y', structed)] ), ), ], - building_blocks.Reference('z', double), + federated_language.framework.Reference('z', double), ), ) after = transformations.to_call_dominant(before) - expected = building_blocks.Lambda( + expected = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Struct( - [building_blocks.Struct([building_blocks.Reference('x', int_type)])] - ), + federated_language.framework.Struct([ + federated_language.framework.Struct( + [federated_language.framework.Reference('x', int_type)] + ) + ]), ) self.assert_compact_representations_equal(after, expected) def test_inlines_selection_from_struct(self): - int_type = computation_types.TensorType(np.int32) - before = building_blocks.Lambda( + int_type = federated_language.TensorType(np.int32) + before = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Selection( - building_blocks.Struct([building_blocks.Reference('x', int_type)]), + federated_language.framework.Selection( + federated_language.framework.Struct( + [federated_language.framework.Reference('x', int_type)] + ), index=0, ), ) after = transformations.to_call_dominant(before) - expected = building_blocks.Lambda( - 'x', int_type, building_blocks.Reference('x', int_type) + expected = federated_language.framework.Lambda( + 'x', int_type, federated_language.framework.Reference('x', int_type) ) self.assert_compact_representations_equal(after, expected) def test_creates_binding_for_each_call(self): - int_type = computation_types.TensorType(np.int32) - int_to_int_type = computation_types.FunctionType(int_type, int_type) + int_type = federated_language.TensorType(np.int32) + int_to_int_type = federated_language.FunctionType(int_type, int_type) any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - int_to_int_fn = building_blocks.Data(any_proto, int_to_int_type) - before = building_blocks.Lambda( + int_to_int_fn = federated_language.framework.Data( + any_proto, int_to_int_type + ) + before = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Call( + federated_language.framework.Call( int_to_int_fn, - building_blocks.Call( - int_to_int_fn, building_blocks.Reference('x', int_type) + federated_language.framework.Call( + int_to_int_fn, + federated_language.framework.Reference('x', int_type), ), ), ) after = transformations.to_call_dominant(before) - expected = building_blocks.Lambda( + expected = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Block( + federated_language.framework.Block( [ ( '_var1', - building_blocks.Call( - int_to_int_fn, building_blocks.Reference('x', int_type) + federated_language.framework.Call( + int_to_int_fn, + federated_language.framework.Reference('x', int_type), ), ), ( '_var2', - building_blocks.Call( + federated_language.framework.Call( int_to_int_fn, - building_blocks.Reference('_var1', int_type), + federated_language.framework.Reference( + '_var1', int_type + ), ), ), ], - building_blocks.Reference('_var2', int_type), + federated_language.framework.Reference('_var2', int_type), ), ) self.assert_compact_representations_equal(after, expected) def test_evaluates_called_lambdas(self): - int_type = computation_types.TensorType(np.int32) - int_to_int_type = computation_types.FunctionType(int_type, int_type) - int_thunk_type = computation_types.FunctionType(None, int_type) + int_type = federated_language.TensorType(np.int32) + int_to_int_type = federated_language.FunctionType(int_type, int_type) + int_thunk_type = federated_language.FunctionType(None, int_type) any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - int_to_int_fn = building_blocks.Data(any_proto, int_to_int_type) + int_to_int_fn = federated_language.framework.Data( + any_proto, int_to_int_type + ) # -> (let result = ext(x) in (-> result)) # Each call of the outer lambda should create a single binding, with # calls to the inner lambda repeatedly returning references to the binding. - higher_fn = building_blocks.Lambda( + higher_fn = federated_language.framework.Lambda( None, None, - building_blocks.Block( + federated_language.framework.Block( [( 'result', - building_blocks.Call( - int_to_int_fn, building_blocks.Reference('x', int_type) + federated_language.framework.Call( + int_to_int_fn, + federated_language.framework.Reference('x', int_type), ), )], - building_blocks.Lambda( - None, None, building_blocks.Reference('result', int_type) + federated_language.framework.Lambda( + None, + None, + federated_language.framework.Reference('result', int_type), ), ), ) @@ -226,127 +239,144 @@ def test_evaluates_called_lambdas(self): # fn = -> (let result = ext(x) in (-> result)) ( 'get_val1', - building_blocks.Call( - building_blocks.Reference('fn', higher_fn.type_signature) + federated_language.framework.Call( + federated_language.framework.Reference( + 'fn', higher_fn.type_signature + ) ), ), # _var2 = ext(x) # get_val1 = -> _var2 ( 'get_val2', - building_blocks.Call( - building_blocks.Reference('fn', higher_fn.type_signature) + federated_language.framework.Call( + federated_language.framework.Reference( + 'fn', higher_fn.type_signature + ) ), ), # _var3 = ext(x) # get_val2 = -> _var3 ( 'val11', - building_blocks.Call( - building_blocks.Reference('get_val1', int_thunk_type) + federated_language.framework.Call( + federated_language.framework.Reference( + 'get_val1', int_thunk_type + ) ), ), # val11 = _var2 ( 'val12', - building_blocks.Call( - building_blocks.Reference('get_val1', int_thunk_type) + federated_language.framework.Call( + federated_language.framework.Reference( + 'get_val1', int_thunk_type + ) ), ), # val12 = _var2 ( 'val2', - building_blocks.Call( - building_blocks.Reference('get_val2', int_thunk_type) + federated_language.framework.Call( + federated_language.framework.Reference( + 'get_val2', int_thunk_type + ) ), ), # val2 = _var3 ] - before = building_blocks.Lambda( + before = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, # <_var2, _var2, _var3> - building_blocks.Struct([ - building_blocks.Reference('val11', int_type), - building_blocks.Reference('val12', int_type), - building_blocks.Reference('val2', int_type), + federated_language.framework.Struct([ + federated_language.framework.Reference('val11', int_type), + federated_language.framework.Reference('val12', int_type), + federated_language.framework.Reference('val2', int_type), ]), ), ) after = transformations.to_call_dominant(before) - expected = building_blocks.Lambda( + expected = federated_language.framework.Lambda( 'x', int_type, - building_blocks.Block( + federated_language.framework.Block( [ ( '_var2', - building_blocks.Call( - int_to_int_fn, building_blocks.Reference('x', int_type) + federated_language.framework.Call( + int_to_int_fn, + federated_language.framework.Reference('x', int_type), ), ), ( '_var3', - building_blocks.Call( - int_to_int_fn, building_blocks.Reference('x', int_type) + federated_language.framework.Call( + int_to_int_fn, + federated_language.framework.Reference('x', int_type), ), ), ], - building_blocks.Struct([ - building_blocks.Reference('_var2', int_type), - building_blocks.Reference('_var2', int_type), - building_blocks.Reference('_var3', int_type), + federated_language.framework.Struct([ + federated_language.framework.Reference('_var2', int_type), + federated_language.framework.Reference('_var2', int_type), + federated_language.framework.Reference('_var3', int_type), ]), ), ) self.assert_compact_representations_equal(after, expected) def test_creates_block_for_non_lambda(self): - int_type = computation_types.TensorType(np.int32) - two_int_type = computation_types.StructType( + int_type = federated_language.TensorType(np.int32) + two_int_type = federated_language.StructType( [(None, int_type), (None, int_type)] ) - get_two_int_type = computation_types.FunctionType(None, two_int_type) + get_two_int_type = federated_language.FunctionType(None, two_int_type) any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - call_ext = building_blocks.Call( - building_blocks.Data(any_proto, get_two_int_type) + call_ext = federated_language.framework.Call( + federated_language.framework.Data(any_proto, get_two_int_type) ) - before = building_blocks.Selection(call_ext, index=0) + before = federated_language.framework.Selection(call_ext, index=0) after = transformations.to_call_dominant(before) - expected = building_blocks.Block( + expected = federated_language.framework.Block( [ ('_var1', call_ext), ], - building_blocks.Selection( - building_blocks.Reference('_var1', two_int_type), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference('_var1', two_int_type), + index=0, ), ) self.assert_compact_representations_equal(after, expected) def test_call_to_higher_order_external_allowed(self): - int_type = computation_types.TensorType(np.int32) - int_to_int_type = computation_types.FunctionType(int_type, int_type) - int_to_int_to_int_type = computation_types.FunctionType( + int_type = federated_language.TensorType(np.int32) + int_to_int_type = federated_language.FunctionType(int_type, int_type) + int_to_int_to_int_type = federated_language.FunctionType( int_to_int_type, int_type ) - call_ext = building_blocks.Call( - building_blocks.Reference('call_with_one', int_to_int_to_int_type), - building_blocks.Lambda( - 'x', int_type, building_blocks.Reference('num', int_type) + call_ext = federated_language.framework.Call( + federated_language.framework.Reference( + 'call_with_one', int_to_int_to_int_type + ), + federated_language.framework.Lambda( + 'x', + int_type, + federated_language.framework.Reference('num', int_type), ), ) after = transformations.to_call_dominant(call_ext) - self.assertIsInstance(after, building_blocks.Block) + self.assertIsInstance(after, federated_language.framework.Block) self.assertLen(after.locals, 1) (ref_name, bound_call) = after.locals[0] self.assertEqual( bound_call.compact_representation(), call_ext.compact_representation() ) - expected_result = building_blocks.Reference( + expected_result = federated_language.framework.Reference( ref_name, call_ext.type_signature ) self.assert_compact_representations_equal(after.result, expected_result) @@ -365,20 +395,26 @@ def assert_splits_on(self, comp, calls): # Ensure that the resulting computations no longer contain the split # intrinsics. - self.assertFalse(tree_analysis.contains_called_intrinsic(before, uris)) - self.assertFalse(tree_analysis.contains_called_intrinsic(after, uris)) + self.assertFalse( + federated_language.framework.contains_called_intrinsic(before, uris) + ) + self.assertFalse( + federated_language.framework.contains_called_intrinsic(after, uris) + ) # Removal isn't interesting to test for if it wasn't there to begin with. - self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uris)) + self.assertTrue( + federated_language.framework.contains_called_intrinsic(comp, uris) + ) if comp.parameter_type is not None: - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( comp.parameter_type, before.parameter_type ) else: self.assertIsNone(before.parameter_type) # There must be one parameter for each intrinsic in `calls`. self.assertIsInstance( - before.type_signature.result, computation_types.StructType + before.type_signature.result, federated_language.StructType ) self.assertLen(before.type_signature.result, len(calls)) @@ -387,10 +423,10 @@ def assert_splits_on(self, comp, calls): # 'original_arg': comp.parameter_type, (if present) # 'intrinsic_results': [...], # } - self.assertIsInstance(after.parameter_type, computation_types.StructType) + self.assertIsInstance(after.parameter_type, federated_language.StructType) if comp.parameter_type is not None: self.assertLen(after.parameter_type, 2) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( comp.parameter_type, after.parameter_type.original_arg ) else: @@ -401,45 +437,51 @@ def assert_splits_on(self, comp, calls): # Check that each pair of (param, result) is a valid type substitution # for the intrinsic in question. for i in range(len(calls)): - concrete_signature = computation_types.FunctionType( + concrete_signature = federated_language.FunctionType( before.type_signature.result[i], after.parameter_type.intrinsic_results[i], ) abstract_signature = calls[i].function.intrinsic_def().type_signature - type_analysis.check_concrete_instance_of( + federated_language.framework.check_concrete_instance_of( concrete_signature, abstract_signature ) def test_cannot_split_on_chained_intrinsic(self): - int_type = computation_types.TensorType(np.int32) - client_int_type = computation_types.FederatedType( - int_type, placements.CLIENTS + int_type = federated_language.TensorType(np.int32) + client_int_type = federated_language.FederatedType( + int_type, federated_language.CLIENTS + ) + int_ref = lambda name: federated_language.framework.Reference( + name, int_type ) - int_ref = lambda name: building_blocks.Reference(name, int_type) def client_int_ref(name): - return building_blocks.Reference(name, client_int_type) + return federated_language.framework.Reference(name, client_int_type) - body = building_blocks.Block( + body = federated_language.framework.Block( [ ( 'a', - building_block_factory.create_federated_map( - building_blocks.Lambda('p1', int_type, int_ref('p1')), + federated_language.framework.create_federated_map( + federated_language.framework.Lambda( + 'p1', int_type, int_ref('p1') + ), client_int_ref('param'), ), ), ( 'b', - building_block_factory.create_federated_map( - building_blocks.Lambda('p2', int_type, int_ref('p2')), + federated_language.framework.create_federated_map( + federated_language.framework.Lambda( + 'p2', int_type, int_ref('p2') + ), client_int_ref('a'), ), ), ], client_int_ref('b'), ) - comp = building_blocks.Lambda('param', client_int_type, body) + comp = federated_language.framework.Lambda('param', client_int_type, body) intrinsic_defaults = [ building_block_test_utils.create_whimsy_called_federated_map('test'), ] @@ -452,8 +494,10 @@ def test_splits_on_intrinsic_noarg_function(self): federated_broadcast = ( building_block_test_utils.create_whimsy_called_federated_broadcast() ) - called_intrinsics = building_blocks.Struct([federated_broadcast]) - comp = building_blocks.Lambda(None, None, called_intrinsics) + called_intrinsics = federated_language.framework.Struct( + [federated_broadcast] + ) + comp = federated_language.framework.Lambda(None, None, called_intrinsics) call = building_block_test_utils.create_whimsy_called_federated_broadcast() self.assert_splits_on(comp, call) @@ -461,8 +505,10 @@ def test_splits_on_selected_intrinsic_broadcast(self): federated_broadcast = ( building_block_test_utils.create_whimsy_called_federated_broadcast() ) - called_intrinsics = building_blocks.Struct([federated_broadcast]) - comp = building_blocks.Lambda('a', np.int32, called_intrinsics) + called_intrinsics = federated_language.framework.Struct( + [federated_broadcast] + ) + comp = federated_language.framework.Lambda('a', np.int32, called_intrinsics) call = building_block_test_utils.create_whimsy_called_federated_broadcast() self.assert_splits_on(comp, call) @@ -473,17 +519,21 @@ def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self): any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - packed_broadcast = building_blocks.Struct([ - building_blocks.Data( + packed_broadcast = federated_language.framework.Struct([ + federated_language.framework.Data( any_proto, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ), first_broadcast, ]) - sel = building_blocks.Selection(packed_broadcast, index=0) - second_broadcast = building_block_factory.create_federated_broadcast(sel) + sel = federated_language.framework.Selection(packed_broadcast, index=0) + second_broadcast = federated_language.framework.create_federated_broadcast( + sel + ) result = transformations.to_call_dominant(second_broadcast) - comp = building_blocks.Lambda('a', np.int32, result) + comp = federated_language.framework.Lambda('a', np.int32, result) call = building_block_test_utils.create_whimsy_called_federated_broadcast() self.assert_splits_on(comp, call) @@ -491,11 +541,11 @@ def test_splits_on_multiple_of_selected_intrinsic_broadcast(self): federated_broadcast = ( building_block_test_utils.create_whimsy_called_federated_broadcast() ) - called_intrinsics = building_blocks.Struct([ + called_intrinsics = federated_language.framework.Struct([ federated_broadcast, federated_broadcast, ]) - comp = building_blocks.Lambda('a', np.int32, called_intrinsics) + comp = federated_language.framework.Lambda('a', np.int32, called_intrinsics) call = building_block_test_utils.create_whimsy_called_federated_broadcast() self.assert_splits_on(comp, call) @@ -507,8 +557,10 @@ def test_splits_on_selected_intrinsic_aggregate(self): report_parameter_name='c', ) ) - called_intrinsics = building_blocks.Struct([federated_aggregate]) - comp = building_blocks.Lambda('d', np.int32, called_intrinsics) + called_intrinsics = federated_language.framework.Struct( + [federated_aggregate] + ) + comp = federated_language.framework.Lambda('d', np.int32, called_intrinsics) call = building_block_test_utils.create_whimsy_called_federated_aggregate( value_type=np.int32 ) @@ -522,11 +574,11 @@ def test_splits_on_multiple_of_selected_intrinsic_aggregate(self): report_parameter_name='c', ) ) - called_intrinsics = building_blocks.Struct([ + called_intrinsics = federated_language.framework.Struct([ federated_aggregate, federated_aggregate, ]) - comp = building_blocks.Lambda('d', np.int32, called_intrinsics) + comp = federated_language.framework.Lambda('d', np.int32, called_intrinsics) call = building_block_test_utils.create_whimsy_called_federated_aggregate() self.assert_splits_on(comp, call) @@ -534,8 +586,10 @@ def test_splits_on_selected_intrinsic_secure_sum_bitwidth(self): federated_secure_sum_bitwidth = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) - called_intrinsics = building_blocks.Struct([federated_secure_sum_bitwidth]) - comp = building_blocks.Lambda('a', np.int32, called_intrinsics) + called_intrinsics = federated_language.framework.Struct( + [federated_secure_sum_bitwidth] + ) + comp = federated_language.framework.Lambda('a', np.int32, called_intrinsics) call = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) @@ -545,11 +599,11 @@ def test_splits_on_multiple_of_selected_intrinsic_secure_sum_bitwidths(self): federated_secure_sum_bitwidth = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) - called_intrinsics = building_blocks.Struct([ + called_intrinsics = federated_language.framework.Struct([ federated_secure_sum_bitwidth, federated_secure_sum_bitwidth, ]) - comp = building_blocks.Lambda('a', np.int32, called_intrinsics) + comp = federated_language.framework.Lambda('a', np.int32, called_intrinsics) call = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) @@ -566,11 +620,11 @@ def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self): federated_secure_sum_bitwidth = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) - called_intrinsics = building_blocks.Struct([ + called_intrinsics = federated_language.framework.Struct([ federated_aggregate, federated_secure_sum_bitwidth, ]) - comp = building_blocks.Lambda('d', np.int32, called_intrinsics) + comp = federated_language.framework.Lambda('d', np.int32, called_intrinsics) null_aggregate = ( building_block_test_utils.create_whimsy_called_federated_aggregate() ) @@ -580,20 +634,30 @@ def test_removes_selected_intrinsic_leaving_remaining_intrinsic(self): comp, [null_aggregate] ) self.assertTrue( - tree_analysis.contains_called_intrinsic(comp, secure_sum_bitwidth_uri) + federated_language.framework.contains_called_intrinsic( + comp, secure_sum_bitwidth_uri + ) ) self.assertTrue( - tree_analysis.contains_called_intrinsic(comp, aggregate_uri) + federated_language.framework.contains_called_intrinsic( + comp, aggregate_uri + ) ) self.assertFalse( - tree_analysis.contains_called_intrinsic(before, aggregate_uri) + federated_language.framework.contains_called_intrinsic( + before, aggregate_uri + ) ) self.assertFalse( - tree_analysis.contains_called_intrinsic(after, aggregate_uri) + federated_language.framework.contains_called_intrinsic( + after, aggregate_uri + ) ) self.assertTrue( - tree_analysis.contains_called_intrinsic(before, secure_sum_bitwidth_uri) - or tree_analysis.contains_called_intrinsic( + federated_language.framework.contains_called_intrinsic( + before, secure_sum_bitwidth_uri + ) + or federated_language.framework.contains_called_intrinsic( after, secure_sum_bitwidth_uri ) ) @@ -609,11 +673,11 @@ def test_splits_on_two_intrinsics(self): federated_secure_sum_bitwidth = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) - called_intrinsics = building_blocks.Struct([ + called_intrinsics = federated_language.framework.Struct([ federated_aggregate, federated_secure_sum_bitwidth, ]) - comp = building_blocks.Lambda('d', np.int32, called_intrinsics) + comp = federated_language.framework.Lambda('d', np.int32, called_intrinsics) self.assert_splits_on( comp, [ @@ -633,13 +697,13 @@ def test_splits_on_multiple_instances_of_two_intrinsics(self): federated_secure_sum_bitwidth = ( building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() ) - called_intrinsics = building_blocks.Struct([ + called_intrinsics = federated_language.framework.Struct([ federated_aggregate, federated_aggregate, federated_secure_sum_bitwidth, federated_secure_sum_bitwidth, ]) - comp = building_blocks.Lambda('d', np.int32, called_intrinsics) + comp = federated_language.framework.Lambda('d', np.int32, called_intrinsics) self.assert_splits_on( comp, [ @@ -656,8 +720,10 @@ def test_splits_even_when_selected_intrinsic_is_not_present(self): report_parameter_name='c', ) ) - called_intrinsics = building_blocks.Struct([federated_aggregate]) - comp = building_blocks.Lambda('d', np.int32, called_intrinsics) + called_intrinsics = federated_language.framework.Struct( + [federated_aggregate] + ) + comp = federated_language.framework.Lambda('d', np.int32, called_intrinsics) transformations.force_align_and_split_by_intrinsics( comp, [ @@ -671,11 +737,11 @@ class AugmentLambdaWithParameterForUnboundReferences(absltest.TestCase): def _check_transformed_comp_validity( self, - original_comp: building_blocks.Lambda, - transformed_comp: building_blocks.ComputationBuildingBlock, + original_comp: federated_language.framework.Lambda, + transformed_comp: federated_language.framework.ComputationBuildingBlock, lambda_parameter_extension_name: str, ): - self.assertIsInstance(transformed_comp, building_blocks.Lambda) + self.assertIsInstance(transformed_comp, federated_language.framework.Lambda) # The transformed lambda comp should have an additional element in the input # parameter named `lambda_parameter_extension_name`. @@ -691,35 +757,41 @@ def _check_transformed_comp_validity( # The transformed lambda comp should have no unbound references. self.assertEmpty( - transformation_utils.get_map_of_unbound_references(transformed_comp)[ + federated_language.framework.get_map_of_unbound_references( transformed_comp - ] + )[transformed_comp] ) def test_identifies_unbound_refs(self): - original_arg_type = computation_types.StructType([np.int32]) - int_at_clients_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + original_arg_type = federated_language.StructType([np.int32]) + int_at_clients_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', original_arg_type, - building_blocks.Block( + federated_language.framework.Block( [( 'a', - building_block_factory.create_federated_sum( - building_blocks.Reference('x', int_at_clients_type) + federated_language.framework.create_federated_sum( + federated_language.framework.Reference( + 'x', int_at_clients_type + ) ), )], - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'a', - computation_types.FederatedType( - np.int32, placements.SERVER + federated_language.FederatedType( + np.int32, federated_language.SERVER ), ), - building_blocks.Reference('y', int_at_clients_type), - building_blocks.Reference('x', int_at_clients_type), + federated_language.framework.Reference( + 'y', int_at_clients_type + ), + federated_language.framework.Reference( + 'x', int_at_clients_type + ), ]), ), ) @@ -737,45 +809,46 @@ def test_identifies_unbound_refs(self): for new_input_comp, expected_new_input_comp in zip( new_input_comps, [ - building_blocks.Reference('x', int_at_clients_type), - building_blocks.Reference('y', int_at_clients_type), - building_blocks.Reference('x', int_at_clients_type), + federated_language.framework.Reference('x', int_at_clients_type), + federated_language.framework.Reference('y', int_at_clients_type), + federated_language.framework.Reference('x', int_at_clients_type), ], ): self.assertEqual(new_input_comp.proto, expected_new_input_comp.proto) def test_identifies_unbound_selections(self): - original_arg_type = computation_types.StructType([np.int32]) - int_at_clients_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + original_arg_type = federated_language.StructType([np.int32]) + int_at_clients_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - federated_sum_param = building_blocks.Selection( - building_blocks.Reference('x', [int_at_clients_type]), index=0 + federated_sum_param = federated_language.framework.Selection( + federated_language.framework.Reference('x', [int_at_clients_type]), + index=0, ) - other_result_param = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference( + other_result_param = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( 'y', [[int_at_clients_type, int_at_clients_type]] ), index=0, ), index=1, ) - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', original_arg_type, - building_blocks.Block( + federated_language.framework.Block( [( 'a', - building_block_factory.create_federated_sum( + federated_language.framework.create_federated_sum( federated_sum_param ), )], - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'a', - computation_types.FederatedType( - np.int32, placements.SERVER + federated_language.FederatedType( + np.int32, federated_language.SERVER ), ), other_result_param, @@ -799,30 +872,32 @@ def test_identifies_unbound_selections(self): self.assertEqual(new_input_comp.proto, expected_new_input_comp.proto) def test_identifies_unbound_refs_in_struct(self): - original_arg_type = computation_types.StructType( - [computation_types.FederatedType(np.int32, placements.CLIENTS)] + original_arg_type = federated_language.StructType( + [federated_language.FederatedType(np.int32, federated_language.CLIENTS)] ) - int_at_clients_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + int_at_clients_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', original_arg_type, - building_blocks.Block( + federated_language.framework.Block( [( 'a', - building_block_factory.create_federated_sum( - building_block_factory.create_federated_zip( - building_blocks.Struct([ - building_blocks.Selection( - building_blocks.Reference( + federated_language.framework.create_federated_sum( + federated_language.framework.create_federated_zip( + federated_language.framework.Struct([ + federated_language.framework.Selection( + federated_language.framework.Reference( 'arg', original_arg_type ), index=0, ), - building_blocks.Reference('b', int_at_clients_type), - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Reference( + 'b', int_at_clients_type + ), + federated_language.framework.Struct([ + federated_language.framework.Reference( 'c', int_at_clients_type ) ]), @@ -830,10 +905,10 @@ def test_identifies_unbound_refs_in_struct(self): ) ), )], - building_blocks.Reference( + federated_language.framework.Reference( 'a', - computation_types.FederatedType( - [np.int32, np.int32, [np.int32]], placements.SERVER + federated_language.FederatedType( + [np.int32, np.int32, [np.int32]], federated_language.SERVER ), ), ), @@ -852,19 +927,20 @@ def test_identifies_unbound_refs_in_struct(self): for new_input_comp, expected_new_input_comp in zip( new_input_comps, [ - building_blocks.Reference('b', int_at_clients_type), - building_blocks.Reference('c', int_at_clients_type), + federated_language.framework.Reference('b', int_at_clients_type), + federated_language.framework.Reference('c', int_at_clients_type), ], ): self.assertEqual(new_input_comp.proto, expected_new_input_comp.proto) def test_no_unbound_refs(self): - original_arg_type = computation_types.StructType([np.int32]) - comp = building_blocks.Lambda( + original_arg_type = federated_language.StructType([np.int32]) + comp = federated_language.framework.Lambda( 'arg', original_arg_type, - building_blocks.Selection( - building_blocks.Reference('arg', original_arg_type), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference('arg', original_arg_type), + index=0, ), ) lambda_parameter_extension_name = 'intermediate_state' @@ -880,11 +956,11 @@ def test_no_unbound_refs(self): self.assertEmpty(new_input_comps) def test_parameter_usage_without_selection(self): - original_arg_type = computation_types.StructType([np.int32]) - comp = building_blocks.Lambda( + original_arg_type = federated_language.StructType([np.int32]) + comp = federated_language.framework.Lambda( 'arg', original_arg_type, - building_blocks.Reference('arg', original_arg_type), + federated_language.framework.Reference('arg', original_arg_type), ) lambda_parameter_extension_name = 'intermediate_state' with self.assertRaises(ValueError): @@ -900,21 +976,23 @@ def find_intrinsics_in_comp(self, comp): def _find_intrinsics(building_block): nonlocal found_intrinsics - if isinstance(building_block, building_blocks.Call) and isinstance( - building_block.function, building_blocks.Intrinsic + if isinstance( + building_block, federated_language.framework.Call + ) and isinstance( + building_block.function, federated_language.framework.Intrinsic ): found_intrinsics.append(building_block.function.uri) - tree_analysis.visit_postorder(comp, _find_intrinsics) + federated_language.framework.visit_postorder(comp, _find_intrinsics) return found_intrinsics def check_split_signatures(self, original_comp, before, intrinsic, after): for comp in (before, intrinsic, after): - self.assertIsInstance(comp, building_blocks.Lambda) + self.assertIsInstance(comp, federated_language.framework.Lambda) self.assertIsInstance( - comp.type_signature.parameter, computation_types.StructType + comp.type_signature.parameter, federated_language.StructType ) - self.assertIsInstance(comp.result, building_blocks.Block) + self.assertIsInstance(comp.result, federated_language.framework.Block) original_comp = transformations.to_call_dominant(original_comp) original_comp = tree_transformations.normalize_types( @@ -922,14 +1000,18 @@ def check_split_signatures(self, original_comp, before, intrinsic, after): ) self.assertIsInstance( - before.type_signature.result, computation_types.StructType + before.type_signature.result, federated_language.StructType ) self.assertEqual( [x for x, _ in structure.to_elements(before.type_signature.result)], ['intrinsic_args_from_before_comp', 'intermediate_state'], ) - self.assertIsInstance(before.result.result[0], building_blocks.Struct) - self.assertIsInstance(before.result.result[1], building_blocks.Struct) + self.assertIsInstance( + before.result.result[0], federated_language.framework.Struct + ) + self.assertIsInstance( + before.result.result[1], federated_language.framework.Struct + ) intrinsic_args_from_before_comp_index = ( len(intrinsic.type_signature.parameter) - 1 @@ -948,7 +1030,7 @@ def check_split_signatures(self, original_comp, before, intrinsic, after): before.type_signature.result[0], ) self.assertIsInstance( - intrinsic.type_signature.result, computation_types.StructType + intrinsic.type_signature.result, federated_language.StructType ) self.assertLen( intrinsic.result.locals, len(intrinsic.type_signature.result) @@ -979,20 +1061,21 @@ def check_split_signatures(self, original_comp, before, intrinsic, after): ) def test_splits_on_intrinsic(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 - intrinsic_call = building_block_factory.create_federated_broadcast( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=server_data_index + intrinsic_call = federated_language.framework.create_federated_broadcast( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=server_data_index, ) ) - comp = building_blocks.Lambda('arg', arg_type, intrinsic_call) + comp = federated_language.framework.Lambda('arg', arg_type, intrinsic_call) # Allow the before and after comps to depend on the entire original comp # input. Do not allow the intrinsic comp to depend on the original comp @@ -1000,7 +1083,7 @@ def test_splits_on_intrinsic(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[], after_comp_allowed_original_arg_subparameters=[()], @@ -1014,7 +1097,7 @@ def test_splits_on_intrinsic(self): self.assertEmpty(self.find_intrinsics_in_comp(before)) self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_BROADCAST.uri], + [federated_language.framework.FEDERATED_BROADCAST.uri], ) self.assertEmpty(self.find_intrinsics_in_comp(after)) @@ -1022,20 +1105,21 @@ def test_splits_on_intrinsic(self): self.assertLen(before.result.result[0], 1) def test_fails_split_with_unavailable_subparameters_to_before_comp(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 - intrinsic_call = building_block_factory.create_federated_broadcast( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=server_data_index + intrinsic_call = federated_language.framework.create_federated_broadcast( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=server_data_index, ) ) - comp = building_blocks.Lambda('arg', arg_type, intrinsic_call) + comp = federated_language.framework.Lambda('arg', arg_type, intrinsic_call) # Do not allow the before or intrinsic comps to depend on the original comp # input at all. This should fail when the before comp attempts to produce @@ -1043,30 +1127,32 @@ def test_fails_split_with_unavailable_subparameters_to_before_comp(self): with self.assertRaises(transformations.UnavailableRequiredInputsError): transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[], intrinsic_comp_allowed_original_arg_subparameters=[], after_comp_allowed_original_arg_subparameters=[()], ) def test_splits_on_intrinsic_with_multiple_args(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] client_data_index = 1 - intrinsic_call = building_block_factory.create_federated_mean( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=client_data_index + intrinsic_call = federated_language.framework.create_federated_mean( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=client_data_index, ), - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=client_data_index + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=client_data_index, ), ) - comp = building_blocks.Lambda('arg', arg_type, intrinsic_call) + comp = federated_language.framework.Lambda('arg', arg_type, intrinsic_call) # Allow the before comp to depend on the client portion of the original # comp input. Don't allow the intrinsic comp to depend on the original comp @@ -1074,10 +1160,10 @@ def test_splits_on_intrinsic_with_multiple_args(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_WEIGHTED_MEAN], - before_comp_allowed_original_arg_subparameters=[ - (client_data_index,) - ], + [federated_language.framework.FEDERATED_WEIGHTED_MEAN], + before_comp_allowed_original_arg_subparameters=[( + client_data_index, + )], intrinsic_comp_allowed_original_arg_subparameters=[], after_comp_allowed_original_arg_subparameters=[()], ) @@ -1090,7 +1176,7 @@ def test_splits_on_intrinsic_with_multiple_args(self): self.assertEmpty(self.find_intrinsics_in_comp(before)) self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri], + [federated_language.framework.FEDERATED_WEIGHTED_MEAN.uri], ) self.assertEmpty(self.find_intrinsics_in_comp(after)) @@ -1100,16 +1186,17 @@ def test_splits_on_intrinsic_with_multiple_args(self): # two args. self.assertLen(intrinsic.result.locals, 1) self.assertIsInstance( - intrinsic.result.locals[0][1].argument, building_blocks.Struct + intrinsic.result.locals[0][1].argument, + federated_language.framework.Struct, ) self.assertLen(intrinsic.result.locals[0][1].argument, 2) def test_splits_on_intrinsic_with_args_from_original_arg(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) intermediate_state_type = [np.int32] server_data_index = 0 @@ -1120,19 +1207,20 @@ def test_splits_on_intrinsic_with_args_from_original_arg(self): (None, client_val_type), ('intermediate_state', intermediate_state_type), ] - intrinsic_call = building_block_factory.create_federated_secure_sum( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=client_data_index + intrinsic_call = federated_language.framework.create_federated_secure_sum( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=client_data_index, ), - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=intermediate_state_index, ), index=0, ), ) - comp = building_blocks.Lambda('arg', arg_type, intrinsic_call) + comp = federated_language.framework.Lambda('arg', arg_type, intrinsic_call) # Allow the before comp to depend on the client portion of the original # comp input and the intrinsic comp to depend on the intermediate @@ -1141,7 +1229,7 @@ def test_splits_on_intrinsic_with_args_from_original_arg(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_SECURE_SUM], + [federated_language.framework.FEDERATED_SECURE_SUM], before_comp_allowed_original_arg_subparameters=[ (client_data_index,), ], @@ -1162,7 +1250,7 @@ def test_splits_on_intrinsic_with_args_from_original_arg(self): self.assertEmpty(self.find_intrinsics_in_comp(before)) self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_SECURE_SUM.uri], + [federated_language.framework.FEDERATED_SECURE_SUM.uri], ) self.assertEmpty(self.find_intrinsics_in_comp(after)) @@ -1171,74 +1259,77 @@ def test_splits_on_intrinsic_with_args_from_original_arg(self): self.assertLen(before.result.result[0], 1) self.assertLen(intrinsic.result.locals, 1) self.assertIsInstance( - intrinsic.result.locals[0][1].argument, building_blocks.Struct + intrinsic.result.locals[0][1].argument, + federated_language.framework.Struct, ) self.assertLen(intrinsic.result.locals[0][1].argument, 2) def test_splits_with_non_empty_before_and_after_block_comps(self): - server_val_type = computation_types.FederatedType( - [np.int32, np.int32], placements.SERVER + server_val_type = federated_language.FederatedType( + [np.int32, np.int32], federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] - inner_server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + inner_server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) server_data_index = 0 - selecting_function = building_blocks.Lambda( + selecting_function = federated_language.framework.Lambda( 'inner_arg', server_val_type.member, - building_blocks.Selection( - building_blocks.Reference('inner_arg', server_val_type.member), + federated_language.framework.Selection( + federated_language.framework.Reference( + 'inner_arg', server_val_type.member + ), index=0, ), ) block_locals = [ ( 'inner_server_data_selection', - building_block_factory.create_federated_map_or_apply( + federated_language.framework.create_federated_map_or_apply( selecting_function, - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ), ), ), ( 'broadcast_result', - building_block_factory.create_federated_broadcast( - building_blocks.Reference( + federated_language.framework.create_federated_broadcast( + federated_language.framework.Reference( 'inner_server_data_selection', inner_server_val_type ) ), ), ( 'another_inner_server_data_selection', - building_block_factory.create_federated_map_or_apply( + federated_language.framework.create_federated_map_or_apply( selecting_function, - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ), ), ), ] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'broadcast_result', - computation_types.FederatedType( - np.int32, placements.CLIENTS + federated_language.FederatedType( + np.int32, federated_language.CLIENTS ), ), - building_blocks.Reference( + federated_language.framework.Reference( 'another_inner_server_data_selection', inner_server_val_type, ), @@ -1252,10 +1343,10 @@ def test_splits_with_non_empty_before_and_after_block_comps(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], - before_comp_allowed_original_arg_subparameters=[ - (server_data_index,) - ], + [federated_language.framework.FEDERATED_BROADCAST], + before_comp_allowed_original_arg_subparameters=[( + server_data_index, + )], intrinsic_comp_allowed_original_arg_subparameters=[], after_comp_allowed_original_arg_subparameters=[()], ) @@ -1268,15 +1359,15 @@ def test_splits_with_non_empty_before_and_after_block_comps(self): # intrinsic. self.assertEqual( self.find_intrinsics_in_comp(before), - [intrinsic_defs.FEDERATED_APPLY.uri], + [federated_language.framework.FEDERATED_APPLY.uri], ) self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_BROADCAST.uri], + [federated_language.framework.FEDERATED_BROADCAST.uri], ) self.assertEqual( self.find_intrinsics_in_comp(after), - [intrinsic_defs.FEDERATED_APPLY.uri], + [federated_language.framework.FEDERATED_APPLY.uri], ) # Check that the before and after comps have blocks with at least one local. @@ -1284,21 +1375,21 @@ def test_splits_with_non_empty_before_and_after_block_comps(self): self.assertNotEmpty(after.result.locals) def test_splits_with_no_matching_intrinsics(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] client_data_index = 1 - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_block_factory.create_federated_sum( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.create_federated_sum( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=client_data_index, ) ), @@ -1309,7 +1400,7 @@ def test_splits_with_no_matching_intrinsics(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], after_comp_allowed_original_arg_subparameters=[()], @@ -1324,27 +1415,28 @@ def test_splits_with_no_matching_intrinsics(self): self.assertEmpty(self.find_intrinsics_in_comp(intrinsic)) self.assertEqual( self.find_intrinsics_in_comp(after), - [intrinsic_defs.FEDERATED_SUM.uri], + [federated_language.framework.FEDERATED_SUM.uri], ) # Check that the intermediate state is empty. self.assertEmpty(before.result.result[1]) def test_splits_with_intermediate_state_for_unbound_refs(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=server_data_index + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=server_data_index, ), ) @@ -1355,7 +1447,7 @@ def test_splits_with_intermediate_state_for_unbound_refs(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], after_comp_allowed_original_arg_subparameters=[], @@ -1370,19 +1462,17 @@ def test_splits_with_intermediate_state_for_unbound_refs(self): self.assertNotEmpty(before.result.result[1]) self.assertEqual( before.result.result[1].proto, - building_blocks.Struct( - [ - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference( - before.parameter_name, before.parameter_type - ), - index=0, + federated_language.framework.Struct([ + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + before.parameter_name, before.parameter_type ), - index=server_data_index, - ) - ] - ).proto, + index=0, + ), + index=server_data_index, + ) + ]).proto, ) # Allow all the output comps to depend on all portions of the original comp @@ -1390,7 +1480,7 @@ def test_splits_with_intermediate_state_for_unbound_refs(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], after_comp_allowed_original_arg_subparameters=[()], @@ -1403,11 +1493,11 @@ def test_splits_with_intermediate_state_for_unbound_refs(self): self.assertEmpty(before.result.result[1]) def test_splits_with_intermediate_state_for_duplication(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] @@ -1416,63 +1506,67 @@ def test_splits_with_intermediate_state_for_duplication(self): proto = mock.create_autospec( computation_pb2.Computation, spec_set=True, instance=True ) - function_type = computation_types.FunctionType(None, np.int32) - compiled = building_blocks.CompiledComputation( + function_type = federated_language.FunctionType(None, np.int32) + compiled = federated_language.framework.CompiledComputation( proto, name='state_1', type_signature=function_type ) - state_1 = building_blocks.Call(compiled, None) - server_state_val_call_1 = building_block_factory.create_federated_value( - state_1, - placements.SERVER, + state_1 = federated_language.framework.Call(compiled, None) + server_state_val_call_1 = ( + federated_language.framework.create_federated_value( + state_1, + federated_language.SERVER, + ) ) block_locals.append(('server_state_val_1', server_state_val_call_1)) proto = mock.create_autospec( computation_pb2.Computation, spec_set=True, instance=True ) - function_type = computation_types.FunctionType(None, np.int32) - compiled = building_blocks.CompiledComputation( + function_type = federated_language.FunctionType(None, np.int32) + compiled = federated_language.framework.CompiledComputation( proto, name='state_2', type_signature=function_type ) - state_2 = building_blocks.Call(compiled, None) - server_state_val_call_2 = building_block_factory.create_federated_value( - state_2, - placements.SERVER, + state_2 = federated_language.framework.Call(compiled, None) + server_state_val_call_2 = ( + federated_language.framework.create_federated_value( + state_2, + federated_language.SERVER, + ) ) block_locals.append(('server_state_val_2', server_state_val_call_2)) - broadcast_call_1 = building_block_factory.create_federated_broadcast( - building_blocks.Reference( + broadcast_call_1 = federated_language.framework.create_federated_broadcast( + federated_language.framework.Reference( 'server_state_val_1', server_state_val_call_1.type_signature ) ) block_locals.append(('broadcast_result_1', broadcast_call_1)) - broadcast_call_2 = building_block_factory.create_federated_broadcast( - building_blocks.Reference( + broadcast_call_2 = federated_language.framework.create_federated_broadcast( + federated_language.framework.Reference( 'server_state_val_2', server_state_val_call_2.type_signature ) ) block_locals.append(('broadcast_result_2', broadcast_call_2)) - federated_zip_call = building_block_factory.create_federated_zip( - building_blocks.Struct([ - building_blocks.Reference( + federated_zip_call = federated_language.framework.create_federated_zip( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'broadcast_result_1', broadcast_call_1.type_signature ), - building_blocks.Reference( + federated_language.framework.Reference( 'broadcast_result_2', broadcast_call_2.type_signature ), ]) ) block_locals.append(('federated_zip_result', federated_zip_call)) - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'federated_zip_result', federated_zip_call.type_signature, ), - building_blocks.Reference( + federated_language.framework.Reference( 'server_state_val_1', server_state_val_call_1.type_signature, ), @@ -1485,7 +1579,7 @@ def test_splits_with_intermediate_state_for_duplication(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], after_comp_allowed_original_arg_subparameters=[()], @@ -1500,42 +1594,45 @@ def test_splits_with_intermediate_state_for_duplication(self): self.assertNotEmpty(before.result.result[1]) def _predicate( - building_block: building_blocks.ComputationBuildingBlock, + building_block: federated_language.framework.ComputationBuildingBlock, ) -> bool: return isinstance( - building_block, building_blocks.Reference + building_block, federated_language.framework.Reference ) and isinstance( - building_block.type_signature, computation_types.FederatedType + building_block.type_signature, federated_language.FederatedType ) self.assertEqual( - tree_analysis.count(before.result.result[1], _predicate), 1 + federated_language.framework.computation_count( + before.result.result[1], _predicate + ), + 1, ) # Check that the before comp has two federated_value_at_server intrinsics. self.assertEqual( self.find_intrinsics_in_comp(before), - [intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri] * 2, + [federated_language.framework.FEDERATED_VALUE_AT_SERVER.uri] * 2, ) # Check that the intrinsic comp has two broadcast intrinsics. self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_BROADCAST.uri] * 2, + [federated_language.framework.FEDERATED_BROADCAST.uri] * 2, ) # Check that the after comp only has a federated_zip_at_clients intrinsic. self.assertEqual( self.find_intrinsics_in_comp(after), - [intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri], + [federated_language.framework.FEDERATED_ZIP_AT_CLIENTS.uri], ) def test_splits_with_intermediate_state_for_duplication_and_unbound_refs( self, ): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 @@ -1545,36 +1642,36 @@ def test_splits_with_intermediate_state_for_duplication_and_unbound_refs( proto = mock.create_autospec( computation_pb2.Computation, spec_set=True, instance=True ) - function_type = computation_types.FunctionType(None, np.int32) - compiled = building_blocks.CompiledComputation( + function_type = federated_language.FunctionType(None, np.int32) + compiled = federated_language.framework.CompiledComputation( proto, name='state', type_signature=function_type ) - state = building_blocks.Call(compiled, None) - server_state_val_call = building_block_factory.create_federated_value( + state = federated_language.framework.Call(compiled, None) + server_state_val_call = federated_language.framework.create_federated_value( state, - placements.SERVER, + federated_language.SERVER, ) block_locals.append(('server_state_val', server_state_val_call)) - broadcast_call = building_block_factory.create_federated_broadcast( - building_blocks.Reference( + broadcast_call = federated_language.framework.create_federated_broadcast( + federated_language.framework.Reference( 'server_state_val', server_state_val_call.type_signature ) ) block_locals.append(('broadcast_result', broadcast_call)) - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'broadcast_result', broadcast_call.type_signature ), - building_blocks.Reference( + federated_language.framework.Reference( 'server_state_val', server_state_val_call.type_signature ), - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ), ]), @@ -1587,7 +1684,7 @@ def test_splits_with_intermediate_state_for_duplication_and_unbound_refs( before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], after_comp_allowed_original_arg_subparameters=[], @@ -1603,23 +1700,23 @@ def test_splits_with_intermediate_state_for_duplication_and_unbound_refs( def _server_data_selection_predicate(bb): return ( - isinstance(bb, building_blocks.Selection) + isinstance(bb, federated_language.framework.Selection) and bb.source.name == before.parameter_name ) def _server_state_val_predicate(bb): - return isinstance(bb, building_blocks.Reference) and isinstance( - bb.type_signature, computation_types.FederatedType - ) + return isinstance( + bb, federated_language.framework.Reference + ) and isinstance(bb.type_signature, federated_language.FederatedType) self.assertEqual( - tree_analysis.count( + federated_language.framework.computation_count( before.result.result[1], _server_data_selection_predicate ), 1, ) self.assertEqual( - tree_analysis.count( + federated_language.framework.computation_count( before.result.result[1], _server_state_val_predicate ), 1, @@ -1628,58 +1725,58 @@ def _server_state_val_predicate(bb): # Check that the before comp has only a federated_value_at_server intrinsic. self.assertEqual( self.find_intrinsics_in_comp(before), - [intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri], + [federated_language.framework.FEDERATED_VALUE_AT_SERVER.uri], ) # Check that the intrinsic comp has only a broadcast intrinsics. self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_BROADCAST.uri], + [federated_language.framework.FEDERATED_BROADCAST.uri], ) # Check that the after comp has no intrinsics. self.assertEmpty(self.find_intrinsics_in_comp(after)) def test_splits_on_multiple_instances_of_intrinsic(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 - broadcast_result_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + broadcast_result_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) block_locals = [] block_locals.append(( 'broadcast_result_1', - building_block_factory.create_federated_broadcast( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.create_federated_broadcast( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ) ), )) block_locals.append(( 'broadcast_result_2', - building_block_factory.create_federated_broadcast( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.create_federated_broadcast( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ) ), )) - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'broadcast_result_1', broadcast_result_type ), - building_blocks.Reference( + federated_language.framework.Reference( 'broadcast_result_2', broadcast_result_type ), ]), @@ -1692,7 +1789,7 @@ def test_splits_on_multiple_instances_of_intrinsic(self): before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[], after_comp_allowed_original_arg_subparameters=[()], @@ -1706,7 +1803,7 @@ def test_splits_on_multiple_instances_of_intrinsic(self): self.assertEmpty(self.find_intrinsics_in_comp(before)) self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_BROADCAST.uri] * 2, + [federated_language.framework.FEDERATED_BROADCAST.uri] * 2, ) self.assertEmpty(self.find_intrinsics_in_comp(after)) @@ -1716,11 +1813,11 @@ def test_splits_on_multiple_instances_of_intrinsic(self): self.assertLen(intrinsic.result.locals, 2) def test_splits_on_multiple_intrinsics(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) intermediate_state_type = [np.int32] server_data_index = 0 @@ -1732,20 +1829,21 @@ def test_splits_on_multiple_intrinsics(self): ('intermediate_state', intermediate_state_type), ] - federated_sum_call = building_block_factory.create_federated_sum( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=client_data_index + federated_sum_call = federated_language.framework.create_federated_sum( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=client_data_index, ) ) federated_secure_sum_call = ( - building_block_factory.create_federated_secure_sum( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.create_federated_secure_sum( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=client_data_index, ), - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=intermediate_state_index, ), index=0, @@ -1756,22 +1854,22 @@ def test_splits_on_multiple_intrinsics(self): ('federated_sum_result', federated_sum_call), ('federated_secure_sum_result', federated_secure_sum_call), ] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, - building_blocks.Struct([ - building_blocks.Reference( + federated_language.framework.Struct([ + federated_language.framework.Reference( 'federated_sum_result', federated_sum_call.type_signature, ), - building_blocks.Reference( + federated_language.framework.Reference( 'federated_secure_sum_result', federated_secure_sum_call.type_signature, ), - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ), ]), @@ -1786,9 +1884,9 @@ def test_splits_on_multiple_intrinsics(self): transformations.divisive_force_align_and_split_by_intrinsics( comp, [ - intrinsic_defs.FEDERATED_SECURE_SUM, - intrinsic_defs.FEDERATED_SUM, - intrinsic_defs.FEDERATED_MEAN, + federated_language.framework.FEDERATED_SECURE_SUM, + federated_language.framework.FEDERATED_SUM, + federated_language.framework.FEDERATED_MEAN, ], before_comp_allowed_original_arg_subparameters=[ (client_data_index,), @@ -1811,8 +1909,8 @@ def test_splits_on_multiple_intrinsics(self): self.assertEqual( set(self.find_intrinsics_in_comp(intrinsic)), set([ - intrinsic_defs.FEDERATED_SUM.uri, - intrinsic_defs.FEDERATED_SECURE_SUM.uri, + federated_language.framework.FEDERATED_SUM.uri, + federated_language.framework.FEDERATED_SECURE_SUM.uri, ]), ) self.assertEmpty(self.find_intrinsics_in_comp(after)) @@ -1826,20 +1924,22 @@ def test_splits_on_multiple_intrinsics(self): self.assertLen(intrinsic.result.locals, 2) # The federated_sum call takes one arg. self.assertNotIsInstance( - intrinsic.result.locals[0][1].argument, building_blocks.Struct + intrinsic.result.locals[0][1].argument, + federated_language.framework.Struct, ) # The federated_secure_sum call takes two args. self.assertIsInstance( - intrinsic.result.locals[1][1].argument, building_blocks.Struct + intrinsic.result.locals[1][1].argument, + federated_language.framework.Struct, ) self.assertLen(intrinsic.result.locals[1][1].argument, 2) def test_cannot_split_on_chained_intrinsic(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 @@ -1847,31 +1947,33 @@ def test_cannot_split_on_chained_intrinsic(self): block_locals = [ ( 'broadcast_result_at_clients', - building_block_factory.create_federated_broadcast( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), + federated_language.framework.create_federated_broadcast( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), index=server_data_index, ) ), ), ( 'federated_sum_result', - building_block_factory.create_federated_sum( - building_blocks.Reference( + federated_language.framework.create_federated_sum( + federated_language.framework.Reference( 'broadcast_result_at_clients', - computation_types.FederatedType( - np.int32, placements.CLIENTS + federated_language.FederatedType( + np.int32, federated_language.CLIENTS ), ) ), ), ] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'arg', arg_type, - building_blocks.Block( + federated_language.framework.Block( block_locals, - building_blocks.Reference('federated_sum_result', server_val_type), + federated_language.framework.Reference( + 'federated_sum_result', server_val_type + ), ), ) @@ -1880,8 +1982,8 @@ def test_cannot_split_on_chained_intrinsic(self): transformations.divisive_force_align_and_split_by_intrinsics( comp, [ - intrinsic_defs.FEDERATED_BROADCAST, - intrinsic_defs.FEDERATED_SUM, + federated_language.framework.FEDERATED_BROADCAST, + federated_language.framework.FEDERATED_SUM, ], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], @@ -1889,35 +1991,41 @@ def test_cannot_split_on_chained_intrinsic(self): ) def test_splits_on_nested_in_tuple_broadcast(self): - server_val_type = computation_types.FederatedType( - np.int32, placements.SERVER + server_val_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - client_val_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + client_val_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) arg_type = [server_val_type, client_val_type] server_data_index = 0 - first_broadcast = building_block_factory.create_federated_broadcast( - building_blocks.Selection( - building_blocks.Reference('arg', arg_type), index=server_data_index + first_broadcast = federated_language.framework.create_federated_broadcast( + federated_language.framework.Selection( + federated_language.framework.Reference('arg', arg_type), + index=server_data_index, ) ) any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - packed_broadcast = building_blocks.Struct( - [building_blocks.Data(any_proto, server_val_type), first_broadcast] + packed_broadcast = federated_language.framework.Struct([ + federated_language.framework.Data(any_proto, server_val_type), + first_broadcast, + ]) + sel = federated_language.framework.Selection(packed_broadcast, index=0) + second_broadcast = federated_language.framework.create_federated_broadcast( + sel + ) + comp = federated_language.framework.Lambda( + 'arg', arg_type, second_broadcast ) - sel = building_blocks.Selection(packed_broadcast, index=0) - second_broadcast = building_block_factory.create_federated_broadcast(sel) - comp = building_blocks.Lambda('arg', arg_type, second_broadcast) # Allow all parts of the split to depend on the entire original comp input. before, intrinsic, after = ( transformations.divisive_force_align_and_split_by_intrinsics( comp, - [intrinsic_defs.FEDERATED_BROADCAST], + [federated_language.framework.FEDERATED_BROADCAST], before_comp_allowed_original_arg_subparameters=[()], intrinsic_comp_allowed_original_arg_subparameters=[()], after_comp_allowed_original_arg_subparameters=[()], @@ -1929,7 +2037,7 @@ def test_splits_on_nested_in_tuple_broadcast(self): self.assertEmpty(self.find_intrinsics_in_comp(before)) self.assertEqual( self.find_intrinsics_in_comp(intrinsic), - [intrinsic_defs.FEDERATED_BROADCAST.uri], + [federated_language.framework.FEDERATED_BROADCAST.uri], ) self.assertEmpty(self.find_intrinsics_in_comp(after)) diff --git a/tensorflow_federated/python/core/impl/compiler/tree_analysis.py b/tensorflow_federated/python/core/impl/compiler/tree_analysis.py deleted file mode 100644 index a0557ef296..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/tree_analysis.py +++ /dev/null @@ -1,602 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permisions and -# limitations under the License. -"""A library of static analysis functions for ASTs.""" - -from collections.abc import Callable -from typing import Optional, Union - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_block_analysis -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis - - -_TypeOrTupleOfTypes = Union[ - type[building_blocks.ComputationBuildingBlock], - tuple[type[building_blocks.ComputationBuildingBlock], ...], -] - - -def visit_preorder( - tree: building_blocks.ComputationBuildingBlock, - function: Callable[[building_blocks.ComputationBuildingBlock], None], -): - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - - def _visit(building_block): - function(building_block) - return building_block, False - - transformation_utils.transform_preorder(tree, _visit) - - -def visit_postorder( - tree: building_blocks.ComputationBuildingBlock, - function: Callable[[building_blocks.ComputationBuildingBlock], None], -): - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - - def _visit(building_block): - function(building_block) - return building_block, False - - transformation_utils.transform_postorder(tree, _visit) - - -_BuildingBlockPredicate = Callable[ - [building_blocks.ComputationBuildingBlock], bool -] - - -def count( - tree: building_blocks.ComputationBuildingBlock, - predicate: Optional[_BuildingBlockPredicate] = None, -) -> int: - """Returns the number of building blocks in `tree` matching `predicate`. - - Args: - tree: A tree of `building_blocks.ComputationBuildingBlock`s to count. - predicate: An optional Python function that takes a tree as a parameter and - returns a boolean value. If `None`, all computations are counted. - """ - counter = 0 - - def _fn(building_block): - nonlocal counter - if predicate is None or predicate(building_block): - counter += 1 - - visit_postorder(tree, _fn) - return counter - - -def contains( - tree: building_blocks.ComputationBuildingBlock, - predicate: _BuildingBlockPredicate, -) -> bool: - """Returns whether or not a building block in `tree` matches `predicate`.""" - return count(tree, predicate) != 0 - - -def contains_only( - tree: building_blocks.ComputationBuildingBlock, - predicate: _BuildingBlockPredicate, -) -> bool: - """Returns whether or not a building block in `tree` matches `predicate`.""" - return not contains(tree, lambda x: not predicate(x)) - - -def check_has_single_placement(comp, single_placement): - """Checks that the AST of `comp` contains only `single_placement`. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock`. - single_placement: Instance of `placements.PlacementLiteral` which should be - the only placement present under `comp`. - - Raises: - ValueError: If the AST under `comp` contains any - `computation_types.FederatedType` other than `single_placement`. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - py_typecheck.check_type(single_placement, placements.PlacementLiteral) - - def _check_single_placement(comp): - """Checks that the placement in `type_spec` matches `single_placement`.""" - if ( - isinstance(comp.type_signature, computation_types.FederatedType) - and comp.type_signature.placement is not single_placement - ): - raise ValueError( - 'Comp contains a placement other than {}; ' - 'placement {} on comp {} inside the structure. '.format( - single_placement, - comp.type_signature.placement, - comp.compact_representation(), - ) - ) - - visit_postorder(comp, _check_single_placement) - - -def check_contains_only_reducible_intrinsics( - comp: building_blocks.ComputationBuildingBlock, -): - """Checks that `comp` contains intrinsics reducible to aggregate or broadcast. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` to check for - presence of intrinsics not currently immediately reducible to - `FEDERATED_AGGREGATE` or `FEDERATED_BROADCAST`, or local processing. - - Raises: - ValueError: If we encounter an intrinsic under `comp` that is not reducible. - """ - reducible_uris = ( - intrinsic_defs.FEDERATED_AGGREGATE.uri, - intrinsic_defs.FEDERATED_APPLY.uri, - intrinsic_defs.FEDERATED_BROADCAST.uri, - intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri, - intrinsic_defs.FEDERATED_MAP.uri, - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, - intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri, - intrinsic_defs.FEDERATED_SECURE_SUM.uri, - intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri, - intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri, - ) - - def _check(comp): - if ( - isinstance(comp, building_blocks.Intrinsic) - and comp.uri not in reducible_uris - ): - raise ValueError( - 'Encountered an Intrinsic not currently reducible to aggregate or ' - 'broadcast, the intrinsic {}'.format(comp.compact_representation()) - ) - - visit_postorder(comp, _check) - - -class NonuniqueNameError(ValueError): - - def __init__(self, comp, name): - self.comp = comp - self.name = name - message = ( - f'The name `{name}` is bound multiple times in the computation:\n' - f'{comp.compact_representation()}' - ) - super().__init__(message) - - -def check_has_unique_names(comp): - """Checks that each variable of `comp` is bound at most once. - - Additionally, checks that `comp` does not mask any names which are unbound - at the top level. - - Args: - comp: Instance of `building_blocks.ComputationBuildingBlock`. - - Raises: - NonuniqueNameError: If we encounter a name that is bound multiple times or a - binding which would shadow an unbound reference. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - # Initializing `names` to unbound names in `comp` ensures that `comp` does not - # mask any names from its parent scope. - names = transformation_utils.get_map_of_unbound_references(comp)[comp] - - def _visit_name(name): - if name in names: - raise NonuniqueNameError(comp, name) - names.add(name) - - def _visit(comp): - if isinstance(comp, building_blocks.Block): - for name, _ in comp.locals: - _visit_name(name) - elif ( - isinstance(comp, building_blocks.Lambda) - and comp.parameter_type is not None - ): - _visit_name(comp.parameter_name) - - visit_postorder(comp, _visit) - - -def extract_nodes_consuming(tree, predicate: _BuildingBlockPredicate): - """Returns the set of AST nodes which consume nodes matching `predicate`. - - Notice we adopt the convention that a node which itself satisfies the - predicate is in this set. - - Args: - tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an - abstract syntax tree, and construct the set of nodes in this tree having a - dependency on nodes matching `predicate`; that is, the set of nodes whose - value depends on evaluating nodes matching `predicate`. - predicate: One-arg callable, accepting arguments of type - `building_blocks.ComputationBuildingBlock` and returning a `bool` - indicating match or mismatch with the desired pattern. - - Returns: - A `set` of `building_blocks.ComputationBuildingBlock` instances - representing the nodes in `tree` dependent on nodes matching `predicate`. - """ - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - - class _NodeSet: - - def __init__(self): - self.mapping = {} - - def add(self, comp): - self.mapping[id(comp)] = comp - - def to_set(self): - return set(self.mapping.values()) - - dependent_nodes = _NodeSet() - - def _are_children_in_dependent_set(comp, symbol_tree): - """Checks if the dependencies of `comp` are present in `dependent_nodes`.""" - if isinstance( - comp, - ( - building_blocks.CompiledComputation, - building_blocks.Data, - building_blocks.Intrinsic, - building_blocks.Literal, - building_blocks.Placement, - ), - ): - return False - elif isinstance(comp, building_blocks.Lambda): - return id(comp.result) in dependent_nodes.mapping - elif isinstance(comp, building_blocks.Block): - return ( - any(id(x[1]) in dependent_nodes.mapping for x in comp.locals) - or id(comp.result) in dependent_nodes.mapping - ) - elif isinstance(comp, building_blocks.Struct): - return any(id(x) in dependent_nodes.mapping for x in comp) - elif isinstance(comp, building_blocks.Selection): - return id(comp.source) in dependent_nodes.mapping - elif isinstance(comp, building_blocks.Call): - return ( - id(comp.function) in dependent_nodes.mapping - or id(comp.argument) in dependent_nodes.mapping - ) - elif isinstance(comp, building_blocks.Reference): - return _is_reference_dependent(comp, symbol_tree) - - def _is_reference_dependent(comp, symbol_tree): - payload = symbol_tree.get_payload_with_name(comp.name) - if payload is None: - return False - # The postorder traversal ensures that we process any - # bindings before we process the reference to those bindings - return id(payload.value) in dependent_nodes.mapping - - def _populate_dependent_set(comp, symbol_tree): - """Populates `dependent_nodes` with all nodes dependent on `predicate`.""" - if predicate(comp): - dependent_nodes.add(comp) - elif _are_children_in_dependent_set(comp, symbol_tree): - dependent_nodes.add(comp) - return comp, False - - symbol_tree = transformation_utils.SymbolTree( - transformation_utils.ReferenceCounter - ) - transformation_utils.transform_postorder_with_symbol_bindings( - tree, _populate_dependent_set, symbol_tree - ) - return dependent_nodes.to_set() - - -def _extract_calls_with_fn_consuming_arg( - tree: building_blocks.ComputationBuildingBlock, - *, - fn_predicate: _BuildingBlockPredicate, - arg_predicate: _BuildingBlockPredicate, -) -> list[building_blocks.Call]: - """Extracts calls depending on function and arg predicates. - - This function returns all calls in `tree` whose fns consume nodes matching - `fn_predicate` and arguments consume nodes matching `arg_predicate`. This - matching can be useful in checking that one type of function does not consume - any nodes depending on another type of function in the body of `tree`. - - Args: - tree: Instance of `building_blocks.ComputationBuildingBlock` to traverse. - fn_predicate: Callable taking a building block and returning a boolean, to - define the behavior of this function according to the semantics above. - arg_predicate: Callable taking a building block and returning a boolean, to - define the behavior of this function according to the semantics above. - - Returns: - A list of `building_block.Calls` matching the description above. - """ - - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - - nodes_dependent_on_arg_predicate = extract_nodes_consuming( - tree, arg_predicate - ) - - nodes_dependent_on_fn_predicate = extract_nodes_consuming(tree, fn_predicate) - - instances = [] - - for node in nodes_dependent_on_arg_predicate: - if isinstance(node, building_blocks.Call): - if ( - node.argument in nodes_dependent_on_arg_predicate - and node.function in nodes_dependent_on_fn_predicate - ): - instances.append(node) - return instances - - -def check_broadcast_not_dependent_on_aggregate(tree): - """Raises if any broadcast in `tree` ingests the result of an aggregate. - - We explicitly check for this pattern since if it occurs, `tree` is not - reducible to broadcast-map-aggregate form. - - - Args: - tree: Instance of `building_blocks.ComputationBuildingBlock` to check for - the presence of a broadcast which ingests the result of an aggregate. - - Raises: - ValueError: If a broadcast in `tree` consumes the result of an aggregate. - """ - - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - - def aggregate_predicate(x): - return ( - isinstance(x, building_blocks.Intrinsic) - and x.intrinsic_def().aggregation_kind - ) - - def broadcast_predicate(x): - return ( - isinstance(x, building_blocks.Intrinsic) - and x.intrinsic_def().broadcast_kind - ) - - broadcast_dependent_examples = _extract_calls_with_fn_consuming_arg( - tree, fn_predicate=broadcast_predicate, arg_predicate=aggregate_predicate - ) - if broadcast_dependent_examples: - raise ValueError( - 'Detected broadcast dependent on aggregate. Examples are: {}'.format( - broadcast_dependent_examples - ) - ) - - -def check_aggregate_not_dependent_on_aggregate(tree): - """Raises if any aggregation in `tree` ingests the result of an aggregate. - - We explicitly check for this pattern since if it occurs, `tree` is not - reducible to `MergeableCompForm`. - - Args: - tree: Instance of `building_blocks.ComputationBuildingBlock` to check for - the presence of an aggregation which ingests the result of another - aggregate. - - Raises: - ValueError: If a broadcast in `tree` consumes the result of an aggregate. - """ - - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - - def aggregate_predicate(x): - return ( - isinstance(x, building_blocks.Intrinsic) - and x.intrinsic_def().aggregation_kind - ) - - multiple_agg_dependent_examples = _extract_calls_with_fn_consuming_arg( - tree, fn_predicate=aggregate_predicate, arg_predicate=aggregate_predicate - ) - if multiple_agg_dependent_examples: - raise ValueError( - 'Detected one aggregate dependent on another. Examples are: {}'.format( - multiple_agg_dependent_examples - ) - ) - - -def check_contains_no_unbound_references(tree, excluding=None): - """Checks that `tree` has no unbound references. - - Args: - tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an - abstract syntax tree. - excluding: A `string` or a collection of `string`s representing the names of - references to exclude from the test. - - Raises: - ValueError: If `comp` has unbound references. - """ - if not contains_no_unbound_references(tree, excluding): - raise ValueError( - 'The AST contains unbound references: {}.'.format( - tree.formatted_representation() - ) - ) - - -def check_contains_no_new_unbound_references(old_tree, new_tree): - """Checks that `new_tree` contains no unbound references not in `old_tree`.""" - old_unbound = transformation_utils.get_map_of_unbound_references(old_tree)[ - old_tree - ] - new_unbound = transformation_utils.get_map_of_unbound_references(new_tree)[ - new_tree - ] - diff = new_unbound - old_unbound - if diff: - raise ValueError( - 'Expected no new unbounded references. ' - f'Old tree:\n{old_tree}\nNew tree:\n{new_tree}\n' - f'New unbound references: {diff}' - ) - - -def contains_called_intrinsic(tree, uri=None): - """Tests if `tree` contains a called intrinsic for the given `uri`. - - Args: - tree: A `building_blocks.ComputationBuildingBlock`. - uri: An optional URI or list of URIs; the same as what is accepted by - `building_block_analysis.is_called_intrinsic`. - - Returns: - `True` if there is a called intrinsic in `tree` for the given `uri`, - otherwise `False`. - """ - predicate = lambda x: building_block_analysis.is_called_intrinsic(x, uri) - return count(tree, predicate) > 0 - - -def contains_no_unbound_references(tree, excluding=None): - """Tests if all the references in `tree` are bound by `tree`. - - Args: - tree: Instance of `building_blocks.ComputationBuildingBlock` to view as an - abstract syntax tree. - excluding: A `string` or a collection of `string`s representing the names of - references to exclude from the test. - - Returns: - `True` if there are no unbound references in `tree` excluding those - specified by `excluding`, otherwise `False`. - """ - py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock) - if isinstance(excluding, str): - excluding = [excluding] - unbound_references = transformation_utils.get_map_of_unbound_references(tree) - if excluding is not None: - excluding = set(excluding) - names = unbound_references[tree] - excluding - else: - names = unbound_references[tree] - num_unbound_references = len(names) - return num_unbound_references == 0 - - -_DEFAULT_KIND_PREDICATE = lambda k: k is not None - - -def find_aggregations_in_tree( - comp, - kind_predicate: Callable[ - [intrinsic_defs.AggregationKind], bool - ] = _DEFAULT_KIND_PREDICATE, -) -> list[building_blocks.Call]: - """Finds aggregating calls with kind matching `kind_predicate` in `comp`. - - An "aggregating call" for the purpose of this function is a call to an - intrinsic which takes values at CLIENT and materializes some result at - SERVER. - - Args: - comp: An AST to search. - kind_predicate: A filter for kind of aggregation to search for. - - Returns: - A list of child ASTs which are calls to aggregating intrinsics with kinds - matching `aggregation_kind`. - - Raises: - ValueError if `comp` contains a call whose target function cannot be - identified. This may result from calls to references or other - indirect structures. - """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - aggregation_calls: list[building_blocks.Call] = [] - - def record_intrinsic_calls(comp): - """Identifies matching calls and adds them to `aggregation_calls`.""" - if not isinstance(comp, building_blocks.Call): - return - # Called lambdas will only trigger aggregation if they themselves contain - # aggregation, which will be caught when the lambea itself is traversed. - if isinstance(comp.function, building_blocks.Lambda): - return - # Aggregation cannot be occurring if the output type is not federated - if not type_analysis.contains_federated_types( - comp.function.type_signature.result - ): - return - - # We can't tell whether an arbitrary AST fragment results in an intrinsic - # with a given URI, so we report an error in this case. - if not isinstance(comp.function, building_blocks.Intrinsic): - raise ValueError( - 'Cannot determine whether call contains aggregation: ' + str(comp) - ) - - # Aggregation with inputs that don't contain any tensors isn't interesting. - # - # NOTE: this is only applicable to intrinsic calls. Users can write their - # own functions that internally materialize values at clients + aggregate - # without taking any input tensors. - # - # This means that this check *must* come after the check above ensuring - # that we're only talking about calls to `building_blocks.Intrinsic`s. - if comp.argument is None or not type_analysis.contains_tensor_types( - comp.argument.type_signature - ): - return - - if kind_predicate(comp.function.intrinsic_def().aggregation_kind): - aggregation_calls.append(comp) - - visit_postorder(comp, record_intrinsic_calls) - return aggregation_calls - - -def find_secure_aggregation_in_tree( - comp: building_blocks.ComputationBuildingBlock, -) -> list[building_blocks.Call]: - """See documentation on `tree_contains_aggregation` for details.""" - return find_aggregations_in_tree( - comp, lambda kind: kind == intrinsic_defs.AggregationKind.SECURE - ) - - -def find_unsecure_aggregation_in_tree( - comp: building_blocks.ComputationBuildingBlock, -) -> list[building_blocks.Call]: - """See documentation on `tree_contains_aggregation` for details.""" - return find_aggregations_in_tree( - comp, lambda kind: kind == intrinsic_defs.AggregationKind.DEFAULT - ) diff --git a/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py b/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py deleted file mode 100644 index fa2932ba1b..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py +++ /dev/null @@ -1,578 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_block_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -class TestCheckContainsOnlyReducibleIntrinsics(absltest.TestCase): - - def test_raises_on_none(self): - with self.assertRaises(TypeError): - tree_analysis.check_contains_only_reducible_intrinsics(None) - - def test_passes_with_federated_map(self): - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MAP.uri, - computation_types.FunctionType( - [ - computation_types.FunctionType(np.int32, np.float32), - computation_types.FederatedType(np.int32, placements.CLIENTS), - ], - computation_types.FederatedType(np.float32, placements.CLIENTS), - ), - ) - tree_analysis.check_contains_only_reducible_intrinsics(intrinsic) - - def test_raises_with_federated_mean(self): - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MEAN.uri, - computation_types.FunctionType( - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.SERVER), - ), - ) - - with self.assertRaisesRegex(ValueError, intrinsic.compact_representation()): - tree_analysis.check_contains_only_reducible_intrinsics(intrinsic) - - -def whimsy_intrinsic_predicate(x): - return ( - isinstance(x, building_blocks.Intrinsic) and x.uri == 'whimsy_intrinsic' - ) - - -class NodesDependentOnPredicateTest(absltest.TestCase): - - def test_raises_on_none_comp(self): - with self.assertRaises(TypeError): - tree_analysis.extract_nodes_consuming(None, lambda x: True) - - def test_raises_on_none_predicate(self): - data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - with self.assertRaises(TypeError): - tree_analysis.extract_nodes_consuming(data, None) - - def test_adds_all_nodes_to_set_with_constant_true_predicate(self): - nested_tree = building_block_test_utils.create_nested_syntax_tree() - all_nodes = tree_analysis.extract_nodes_consuming( - nested_tree, lambda x: True - ) - node_count = tree_analysis.count(nested_tree) - self.assertLen(all_nodes, node_count) - - def test_adds_no_nodes_to_set_with_constant_false_predicate(self): - nested_tree = building_block_test_utils.create_nested_syntax_tree() - all_nodes = tree_analysis.extract_nodes_consuming( - nested_tree, lambda x: False - ) - self.assertEmpty(all_nodes) - - def test_propogates_dependence_up_through_lambda(self): - type_signature = computation_types.TensorType(np.int32) - whimsy_intrinsic = building_blocks.Intrinsic( - 'whimsy_intrinsic', type_signature - ) - lam = building_blocks.Lambda('x', np.int32, whimsy_intrinsic) - dependent_nodes = tree_analysis.extract_nodes_consuming( - lam, whimsy_intrinsic_predicate - ) - self.assertIn(lam, dependent_nodes) - - def test_propogates_dependence_up_through_block_result(self): - type_signature = computation_types.TensorType(np.int32) - whimsy_intrinsic = building_blocks.Intrinsic( - 'whimsy_intrinsic', type_signature - ) - integer_reference = building_blocks.Reference('int', np.int32) - block = building_blocks.Block([('x', integer_reference)], whimsy_intrinsic) - dependent_nodes = tree_analysis.extract_nodes_consuming( - block, whimsy_intrinsic_predicate - ) - self.assertIn(block, dependent_nodes) - - def test_propogates_dependence_up_through_block_locals(self): - type_signature = computation_types.TensorType(np.int32) - whimsy_intrinsic = building_blocks.Intrinsic( - 'whimsy_intrinsic', type_signature - ) - integer_reference = building_blocks.Reference('int', np.int32) - block = building_blocks.Block([('x', whimsy_intrinsic)], integer_reference) - dependent_nodes = tree_analysis.extract_nodes_consuming( - block, whimsy_intrinsic_predicate - ) - self.assertIn(block, dependent_nodes) - - def test_propogates_dependence_up_through_tuple(self): - type_signature = computation_types.TensorType(np.int32) - whimsy_intrinsic = building_blocks.Intrinsic( - 'whimsy_intrinsic', type_signature - ) - integer_reference = building_blocks.Reference('int', np.int32) - tup = building_blocks.Struct([integer_reference, whimsy_intrinsic]) - dependent_nodes = tree_analysis.extract_nodes_consuming( - tup, whimsy_intrinsic_predicate - ) - self.assertIn(tup, dependent_nodes) - - def test_propogates_dependence_up_through_selection(self): - type_signature = computation_types.StructType([np.int32]) - whimsy_intrinsic = building_blocks.Intrinsic( - 'whimsy_intrinsic', type_signature - ) - selection = building_blocks.Selection(whimsy_intrinsic, index=0) - dependent_nodes = tree_analysis.extract_nodes_consuming( - selection, whimsy_intrinsic_predicate - ) - self.assertIn(selection, dependent_nodes) - - def test_propogates_dependence_up_through_call(self): - type_signature = computation_types.TensorType(np.int32) - whimsy_intrinsic = building_blocks.Intrinsic( - 'whimsy_intrinsic', type_signature - ) - ref_to_x = building_blocks.Reference('x', np.int32) - identity_lambda = building_blocks.Lambda('x', np.int32, ref_to_x) - called_lambda = building_blocks.Call(identity_lambda, whimsy_intrinsic) - dependent_nodes = tree_analysis.extract_nodes_consuming( - called_lambda, whimsy_intrinsic_predicate - ) - self.assertIn(called_lambda, dependent_nodes) - - def test_propogates_dependence_into_binding_to_reference(self): - fed_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - ref_to_x = building_blocks.Reference('x', fed_type) - federated_zero = building_blocks.Intrinsic( - intrinsic_defs.GENERIC_ZERO.uri, fed_type - ) - - def federated_zero_predicate(x): - return ( - isinstance(x, building_blocks.Intrinsic) - and x.uri == intrinsic_defs.GENERIC_ZERO.uri - ) - - block = building_blocks.Block([('x', federated_zero)], ref_to_x) - dependent_nodes = tree_analysis.extract_nodes_consuming( - block, federated_zero_predicate - ) - self.assertIn(ref_to_x, dependent_nodes) - - -class BroadcastDependentOnAggregateTest(absltest.TestCase): - - def test_raises_on_none_comp(self): - with self.assertRaises(TypeError): - tree_analysis.check_broadcast_not_dependent_on_aggregate(None) - - def test_does_not_find_aggregate_dependent_on_broadcast(self): - broadcast = ( - building_block_test_utils.create_whimsy_called_federated_broadcast() - ) - value_type = broadcast.type_signature - zero = building_blocks.Literal(1, value_type.member) - accumulate_result = building_blocks.Literal(2, value_type.member) - accumulate = building_blocks.Lambda( - 'accumulate_parameter', - [value_type.member, value_type.member], - accumulate_result, - ) - merge_result = building_blocks.Literal(3, value_type.member) - merge = building_blocks.Lambda( - 'merge_parameter', [value_type.member, value_type.member], merge_result - ) - report_result = building_blocks.Literal(4, value_type.member) - report = building_blocks.Lambda( - 'report_parameter', value_type.member, report_result - ) - aggregate_dependent_on_broadcast = ( - building_block_factory.create_federated_aggregate( - broadcast, zero, accumulate, merge, report - ) - ) - tree_analysis.check_broadcast_not_dependent_on_aggregate( - aggregate_dependent_on_broadcast - ) - - def test_finds_broadcast_dependent_on_aggregate(self): - aggregate = ( - building_block_test_utils.create_whimsy_called_federated_aggregate() - ) - broadcasted_aggregate = building_block_factory.create_federated_broadcast( - aggregate - ) - with self.assertRaises(ValueError): - tree_analysis.check_broadcast_not_dependent_on_aggregate( - broadcasted_aggregate - ) - - def test_returns_correct_example_of_broadcast_dependent_on_aggregate(self): - aggregate = ( - building_block_test_utils.create_whimsy_called_federated_aggregate() - ) - broadcasted_aggregate = building_block_factory.create_federated_broadcast( - aggregate - ) - with self.assertRaisesRegex(ValueError, 'acc_param'): - tree_analysis.check_broadcast_not_dependent_on_aggregate( - broadcasted_aggregate - ) - - -class AggregateDependentOnAggregateTest(absltest.TestCase): - - def test_raises_on_none_comp(self): - with self.assertRaises(TypeError): - tree_analysis.check_aggregate_not_dependent_on_aggregate(None) - - def test_does_not_find_aggregate_dependent_on_broadcast(self): - broadcast = ( - building_block_test_utils.create_whimsy_called_federated_broadcast() - ) - value_type = broadcast.type_signature - zero = building_blocks.Literal(1, value_type.member) - accumulate_result = building_blocks.Literal(2, value_type.member) - accumulate = building_blocks.Lambda( - 'accumulate_parameter', - [value_type.member, value_type.member], - accumulate_result, - ) - merge_result = building_blocks.Literal(3, value_type.member) - merge = building_blocks.Lambda( - 'merge_parameter', [value_type.member, value_type.member], merge_result - ) - report_result = building_blocks.Literal(4, value_type.member) - report = building_blocks.Lambda( - 'report_parameter', value_type.member, report_result - ) - aggregate_dependent_on_broadcast = ( - building_block_factory.create_federated_aggregate( - broadcast, zero, accumulate, merge, report - ) - ) - tree_analysis.check_aggregate_not_dependent_on_aggregate( - aggregate_dependent_on_broadcast - ) - - def test_finds_aggregate_dependent_on_aggregate(self): - aggregate = ( - building_block_test_utils.create_whimsy_called_federated_aggregate() - ) - broadcasted_aggregate = building_block_factory.create_federated_broadcast( - aggregate - ) - second_aggregate = building_block_factory.create_federated_sum( - broadcasted_aggregate - ) - with self.assertRaises(ValueError): - tree_analysis.check_aggregate_not_dependent_on_aggregate(second_aggregate) - - -class ContainsCalledIntrinsic(absltest.TestCase): - - def test_raises_type_error_with_none_tree(self): - with self.assertRaises(TypeError): - tree_analysis.contains_called_intrinsic(None) - - def test_returns_true_with_none_uri(self): - comp = building_block_test_utils.create_whimsy_called_federated_broadcast() - self.assertTrue(tree_analysis.contains_called_intrinsic(comp)) - - def test_returns_true_with_matching_uri(self): - comp = building_block_test_utils.create_whimsy_called_federated_broadcast() - uri = intrinsic_defs.FEDERATED_BROADCAST.uri - self.assertTrue(tree_analysis.contains_called_intrinsic(comp, uri)) - - def test_returns_false_with_no_called_intrinsic(self): - comp = building_block_test_utils.create_identity_function('a') - self.assertFalse(tree_analysis.contains_called_intrinsic(comp)) - - def test_returns_false_with_unmatched_called_intrinsic(self): - comp = building_block_test_utils.create_whimsy_called_federated_broadcast() - uri = intrinsic_defs.FEDERATED_MAP.uri - self.assertFalse(tree_analysis.contains_called_intrinsic(comp, uri)) - - -class ContainsNoUnboundReferencesTest(absltest.TestCase): - - def test_raises_type_error_with_none_tree(self): - with self.assertRaises(TypeError): - tree_analysis.contains_no_unbound_references(None) - - def test_raises_type_error_with_int_excluding(self): - ref = building_blocks.Reference('a', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - with self.assertRaises(TypeError): - tree_analysis.contains_no_unbound_references(fn, 1) - - def test_returns_true(self): - ref = building_blocks.Reference('a', np.int32) - fn = building_blocks.Lambda(ref.name, ref.type_signature, ref) - self.assertTrue(tree_analysis.contains_no_unbound_references(fn)) - - def test_returns_true_with_excluded_reference(self): - ref = building_blocks.Reference('a', np.int32) - fn = building_blocks.Lambda('b', np.int32, ref) - self.assertTrue( - tree_analysis.contains_no_unbound_references(fn, excluding='a') - ) - - def test_returns_false(self): - ref = building_blocks.Reference('a', np.int32) - fn = building_blocks.Lambda('b', np.int32, ref) - self.assertFalse(tree_analysis.contains_no_unbound_references(fn)) - - -def _create_trivial_mean(value_type=np.int32): - """Returns a trivial federated mean.""" - fed_value_type = computation_types.FederatedType( - value_type, placements.CLIENTS - ) - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3]) - ) - values = building_blocks.Data(any_proto, fed_value_type) - - return building_block_factory.create_federated_mean(values, None) - - -def _create_trivial_secure_sum(value_type=np.int32): - """Returns a trivial secure sum.""" - federated_type = computation_types.FederatedType( - value_type, placements.CLIENTS - ) - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3]) - ) - value = building_blocks.Data(any_proto, federated_type) - bitwidth = building_blocks.Data(any_proto, value_type) - return building_block_factory.create_federated_secure_sum_bitwidth( - value, bitwidth - ) - - -non_aggregation_intrinsics = building_blocks.Struct([ - ( - None, - building_block_test_utils.create_whimsy_called_federated_broadcast(), - ), - ( - None, - building_block_test_utils.create_whimsy_called_federated_value( - placements.CLIENTS - ), - ), -]) - -trivial_mean = _create_trivial_mean(value_type=computation_types.StructType([])) -trivial_secure_sum = _create_trivial_secure_sum( - value_type=computation_types.StructType([]) -) - - -class ContainsAggregationShared(parameterized.TestCase): - - @parameterized.named_parameters([ - ('non_aggregation_intrinsics', non_aggregation_intrinsics), - ('trivial_mean', trivial_mean), - # TODO: b/120439632 - Enable once federated_mean accepts structured - # weight. - # ('trivial_weighted_mean', trivial_weighted_mean), - ('trivial_secure_sum', trivial_secure_sum), - ]) - def test_returns_none(self, comp): - self.assertEmpty(tree_analysis.find_unsecure_aggregation_in_tree(comp)) - self.assertEmpty(tree_analysis.find_secure_aggregation_in_tree(comp)) - - def test_throws_on_unresolvable_function_call(self): - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3]) - ) - comp = building_blocks.Call( - building_blocks.Data( - any_proto, - computation_types.FunctionType( - None, - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ) - ) - with self.assertRaises(ValueError): - tree_analysis.find_unsecure_aggregation_in_tree(comp) - with self.assertRaises(ValueError): - tree_analysis.find_secure_aggregation_in_tree(comp) - - # functions without a federated output can't aggregate - def test_returns_none_on_unresolvable_function_call_with_non_federated_output( - self, - ): - input_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - output_type = np.int32 - any_proto = building_block_test_utils.create_any_proto_from_array( - np.array([1, 2, 3]) - ) - comp = building_blocks.Call( - building_blocks.Data( - any_proto, - computation_types.FunctionType(input_type, output_type), - ), - building_blocks.Data(any_proto, input_type), - ) - - self.assertEmpty(tree_analysis.find_unsecure_aggregation_in_tree(comp)) - self.assertEmpty(tree_analysis.find_secure_aggregation_in_tree(comp)) - - -simple_aggregate = ( - building_block_test_utils.create_whimsy_called_federated_aggregate() -) -simple_mean = building_block_test_utils.create_whimsy_called_federated_mean() -simple_sum = building_block_test_utils.create_whimsy_called_federated_sum() -simple_weighted_mean = ( - building_block_test_utils.create_whimsy_called_federated_mean( - np.float32, np.float32 - ) -) -simple_secure_sum = ( - building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth() -) - - -class ContainsSecureAggregation(parameterized.TestCase): - - @parameterized.named_parameters([ - ('simple_aggregate', simple_aggregate), - ('simple_mean', simple_mean), - ('simple_sum', simple_sum), - ('simple_weighted_mean', simple_weighted_mean), - ]) - def test_returns_none_on_unsecure_aggregation(self, comp): - self.assertEmpty(tree_analysis.find_secure_aggregation_in_tree(comp)) - - def assert_one_aggregation(self, comp): - self.assertLen(tree_analysis.find_secure_aggregation_in_tree(comp), 1) - - def test_returns_str_on_simple_secure_aggregation(self): - self.assert_one_aggregation(simple_secure_sum) - - def test_returns_str_on_nested_secure_aggregation(self): - comp = _create_trivial_secure_sum((np.int32, np.int32)) - self.assert_one_aggregation(comp) - - -class ContainsUnsecureAggregation(parameterized.TestCase): - - def test_returns_none_on_secure_aggregation(self): - self.assertEmpty( - tree_analysis.find_unsecure_aggregation_in_tree(simple_secure_sum) - ) - - @parameterized.named_parameters([ - ('simple_aggregate', simple_aggregate), - ('simple_mean', simple_mean), - ('simple_sum', simple_sum), - ('simple_weighted_mean', simple_weighted_mean), - ]) - def test_returns_one_on_unsecure_aggregation(self, comp): - self.assertLen(tree_analysis.find_unsecure_aggregation_in_tree(comp), 1) - - -class CheckHasUniqueNamesTest(absltest.TestCase): - - def test_raises_on_none(self): - with self.assertRaises(TypeError): - tree_analysis.check_has_unique_names(None) - - def test_ok_on_single_lambda(self): - ref_to_x = building_blocks.Reference('x', np.int32) - lambda_1 = building_blocks.Lambda('x', np.int32, ref_to_x) - tree_analysis.check_has_unique_names(lambda_1) - - def test_ok_on_multiple_no_arg_lambdas(self): - data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - lambda_1 = building_blocks.Lambda(None, None, data) - lambda_2 = building_blocks.Lambda(None, None, data) - tup = building_blocks.Struct([lambda_1, lambda_2]) - tree_analysis.check_has_unique_names(tup) - - def test_raises_on_nested_lambdas_with_same_variable_name(self): - ref_to_x = building_blocks.Reference('x', np.int32) - lambda_1 = building_blocks.Lambda('x', np.int32, ref_to_x) - lambda_2 = building_blocks.Lambda('x', np.int32, lambda_1) - with self.assertRaises(tree_analysis.NonuniqueNameError): - tree_analysis.check_has_unique_names(lambda_2) - - def test_ok_on_nested_lambdas_with_different_variable_name(self): - ref_to_x = building_blocks.Reference('x', np.int32) - lambda_1 = building_blocks.Lambda('x', np.int32, ref_to_x) - lambda_2 = building_blocks.Lambda('y', np.int32, lambda_1) - tree_analysis.check_has_unique_names(lambda_2) - - def test_ok_on_single_block(self): - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - single_block = building_blocks.Block([('x', x_data)], x_data) - tree_analysis.check_has_unique_names(single_block) - - def test_raises_on_sequential_binding_of_same_variable_in_block(self): - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block = building_blocks.Block([('x', x_data), ('x', x_data)], x_data) - with self.assertRaises(tree_analysis.NonuniqueNameError): - tree_analysis.check_has_unique_names(block) - - def test_ok_on_sequential_binding_of_different_variable_in_block(self): - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block = building_blocks.Block([('x', x_data), ('y', x_data)], x_data) - tree_analysis.check_has_unique_names(block) - - def test_raises_block_rebinding_of_lambda_variable(self): - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - single_block = building_blocks.Block([('x', x_data)], x_data) - lambda_1 = building_blocks.Lambda('x', np.int32, single_block) - with self.assertRaises(tree_analysis.NonuniqueNameError): - tree_analysis.check_has_unique_names(lambda_1) - - def test_ok_block_binding_of_new_variable(self): - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - single_block = building_blocks.Block([('x', x_data)], x_data) - lambda_1 = building_blocks.Lambda('y', np.int32, single_block) - tree_analysis.check_has_unique_names(lambda_1) - - def test_raises_lambda_rebinding_of_block_variable(self): - x_ref = building_blocks.Reference('x', np.int32) - lambda_1 = building_blocks.Lambda('x', np.int32, x_ref) - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - single_block = building_blocks.Block([('x', x_data)], lambda_1) - with self.assertRaises(tree_analysis.NonuniqueNameError): - tree_analysis.check_has_unique_names(single_block) - - def test_ok_lambda_binding_of_new_variable(self): - y_ref = building_blocks.Reference('y', np.int32) - lambda_1 = building_blocks.Lambda('y', np.int32, y_ref) - x_data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - single_block = building_blocks.Block([('x', x_data)], lambda_1) - tree_analysis.check_has_unique_names(single_block) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/tree_transformations.py b/tensorflow_federated/python/core/impl/compiler/tree_transformations.py index 49ee99adab..f301293991 100644 --- a/tensorflow_federated/python/core/impl/compiler/tree_transformations.py +++ b/tensorflow_federated/python/core/impl/compiler/tree_transformations.py @@ -15,16 +15,9 @@ from collections.abc import Sequence +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_block_analysis -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_transformations def remove_mapped_or_applied_identity(comp): @@ -59,23 +52,25 @@ def remove_mapped_or_applied_identity(comp): Raises: TypeError: If types do not match. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) def _should_transform(comp): """Returns `True` if `comp` is a mapped or applied identity function.""" if ( - isinstance(comp, building_blocks.Call) - and isinstance(comp.function, building_blocks.Intrinsic) + isinstance(comp, federated_language.framework.Call) + and isinstance(comp.function, federated_language.framework.Intrinsic) and comp.function.uri in ( - intrinsic_defs.FEDERATED_MAP.uri, - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, - intrinsic_defs.FEDERATED_APPLY.uri, - intrinsic_defs.SEQUENCE_MAP.uri, + federated_language.framework.FEDERATED_MAP.uri, + federated_language.framework.FEDERATED_MAP_ALL_EQUAL.uri, + federated_language.framework.FEDERATED_APPLY.uri, + federated_language.framework.SEQUENCE_MAP.uri, ) ): called_function = comp.argument[0] - return building_block_analysis.is_identity_function(called_function) + return federated_language.framework.is_identity_function(called_function) return False def _transform(comp): @@ -84,21 +79,23 @@ def _transform(comp): transformed_comp = comp.argument[1] return transformed_comp, True - return transformation_utils.transform_postorder(comp, _transform) + return federated_language.framework.transform_postorder(comp, _transform) -class RemoveUnusedBlockLocals(transformation_utils.TransformSpec): +class RemoveUnusedBlockLocals(federated_language.framework.TransformSpec): """Removes block local variables which are not used in the result.""" def should_transform(self, comp): - return isinstance(comp, building_blocks.Block) + return isinstance(comp, federated_language.framework.Block) def transform(self, comp): if not self.should_transform(comp): return comp, False - unbound_ref_set = transformation_utils.get_map_of_unbound_references( - comp.result - )[comp.result] + unbound_ref_set = ( + federated_language.framework.get_map_of_unbound_references(comp.result)[ + comp.result + ] + ) if (not unbound_ref_set) or (not comp.locals): return comp.result, True new_locals = [] @@ -106,19 +103,22 @@ def transform(self, comp): if name in unbound_ref_set: new_locals.append((name, val)) unbound_ref_set = unbound_ref_set.union( - transformation_utils.get_map_of_unbound_references(val)[val] + federated_language.framework.get_map_of_unbound_references(val)[val] ) unbound_ref_set.discard(name) if len(new_locals) == len(comp.locals): return comp, False elif not new_locals: return comp.result, True - return building_blocks.Block(reversed(new_locals), comp.result), True + return ( + federated_language.framework.Block(reversed(new_locals), comp.result), + True, + ) def remove_unused_block_locals(comp): transform_spec = RemoveUnusedBlockLocals() - return transformation_utils.transform_postorder( + return federated_language.framework.transform_postorder( comp, transform_spec.transform ) @@ -139,12 +139,14 @@ def uniquify_reference_names(comp, name_generator=None): are guaranteed to be unique, and are guaranteed to not mask any unbound names referenced in the body of `comp`. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) # Passing `comp` to `unique_name_generator` here will ensure that the # generated names conflict with neither bindings in `comp` nor unbound # references in `comp`. if name_generator is None: - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) rename_all = False else: # If a `name_generator` was passed in, all bindings must be renamed since @@ -152,8 +154,8 @@ def uniquify_reference_names(comp, name_generator=None): rename_all = True used_names = set() - class _RenameNode(transformation_utils.BoundVariableTracker): - """transformation_utils.SymbolTree node for renaming References in ASTs.""" + class _RenameNode(federated_language.framework.BoundVariableTracker): + """federated_language.framework.SymbolTree node for renaming References in ASTs.""" def __init__(self, name, value): super().__init__(name, value) @@ -171,7 +173,7 @@ def __str__(self): def _transform(comp, context_tree): """Renames References in `comp` to unique names.""" - if isinstance(comp, building_blocks.Reference): + if isinstance(comp, federated_language.framework.Reference): payload = context_tree.get_payload_with_name(comp.name) if payload is None: return comp, False @@ -179,12 +181,12 @@ def _transform(comp, context_tree): if new_name is comp.name: return comp, False return ( - building_blocks.Reference( + federated_language.framework.Reference( new_name, comp.type_signature, comp.context ), True, ) - elif isinstance(comp, building_blocks.Block): + elif isinstance(comp, federated_language.framework.Block): new_locals = [] modified = False for name, val in comp.locals: @@ -192,8 +194,11 @@ def _transform(comp, context_tree): new_name = context_tree.get_payload_with_name(name).new_name modified = modified or (new_name is not name) new_locals.append((new_name, val)) - return building_blocks.Block(new_locals, comp.result), modified - elif isinstance(comp, building_blocks.Lambda): + return ( + federated_language.framework.Block(new_locals, comp.result), + modified, + ) + elif isinstance(comp, federated_language.framework.Lambda): if comp.parameter_type is None: return comp, False context_tree.walk_down_one_variable_binding() @@ -203,13 +208,15 @@ def _transform(comp, context_tree): if new_name is comp.parameter_name: return comp, False return ( - building_blocks.Lambda(new_name, comp.parameter_type, comp.result), + federated_language.framework.Lambda( + new_name, comp.parameter_type, comp.result + ), True, ) return comp, False - symbol_tree = transformation_utils.SymbolTree(_RenameNode) - return transformation_utils.transform_postorder_with_symbol_bindings( + symbol_tree = federated_language.framework.SymbolTree(_RenameNode) + return federated_language.framework.transform_postorder_with_symbol_bindings( comp, _transform, symbol_tree ) @@ -234,7 +241,8 @@ def normalize_types(comp, normalize_all_equal_bit: bool = True): intrinsic. Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` to transform. + comp: Instance of `federated_language.framework.ComputationBuildingBlock` to + transform. normalize_all_equal_bit: Whether to normalize `all_equal` bits in the placed values. Should be set to true when compiling for MapReduceForm and false when compiling for DistributeAggregateForm. @@ -242,26 +250,28 @@ def normalize_types(comp, normalize_all_equal_bit: bool = True): Returns: A modified version of `comp` with normalized types. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) def _normalize_type_signature_helper(type_signature): - if isinstance(type_signature, computation_types.FederatedType): + if isinstance(type_signature, federated_language.FederatedType): if normalize_all_equal_bit: - return computation_types.FederatedType( + return federated_language.FederatedType( type_signature.member, type_signature.placement ) - elif isinstance(type_signature, computation_types.StructType): + elif isinstance(type_signature, federated_language.StructType): new_elements = [] for element_name, element_type in type_signature.items(): new_elements.append( (element_name, _normalize_type_signature_helper(element_type)) ) - return computation_types.StructType(new_elements) + return federated_language.StructType(new_elements) return type_signature def _normalize_reference_bit(comp): return ( - building_blocks.Reference( + federated_language.framework.Reference( comp.name, _normalize_type_signature_helper(comp.type_signature) ), True, @@ -271,7 +281,7 @@ def _normalize_lambda_bit(comp): # Note that the lambda result has already been normalized due to the post- # order traversal. return ( - building_blocks.Lambda( + federated_language.framework.Lambda( comp.parameter_name, _normalize_type_signature_helper(comp.parameter_type), comp.result, @@ -281,46 +291,49 @@ def _normalize_lambda_bit(comp): def _normalize_intrinsic_bit(comp): """Replaces federated map all equal with federated map.""" - if comp.uri != intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri: + if comp.uri != federated_language.framework.FEDERATED_MAP_ALL_EQUAL.uri: return comp, False parameter_type = [ comp.type_signature.parameter[0], - computation_types.FederatedType( - comp.type_signature.parameter[1].member, placements.CLIENTS + federated_language.FederatedType( + comp.type_signature.parameter[1].member, federated_language.CLIENTS ), ] - intrinsic_type = computation_types.FunctionType( + intrinsic_type = federated_language.FunctionType( parameter_type, - computation_types.FederatedType( - comp.type_signature.result.member, placements.CLIENTS + federated_language.FederatedType( + comp.type_signature.result.member, federated_language.CLIENTS ), ) - new_intrinsic = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_MAP.uri, intrinsic_type + new_intrinsic = federated_language.framework.Intrinsic( + federated_language.framework.FEDERATED_MAP.uri, intrinsic_type ) return new_intrinsic, True def _transform_switch(comp): - if isinstance(comp, building_blocks.Reference): + if isinstance(comp, federated_language.framework.Reference): return _normalize_reference_bit(comp) - elif isinstance(comp, building_blocks.Lambda): + elif isinstance(comp, federated_language.framework.Lambda): return _normalize_lambda_bit(comp) elif ( - isinstance(comp, building_blocks.Intrinsic) and normalize_all_equal_bit + isinstance(comp, federated_language.framework.Intrinsic) + and normalize_all_equal_bit ): return _normalize_intrinsic_bit(comp) return comp, False - return transformation_utils.transform_postorder(comp, _transform_switch)[0] + return federated_language.framework.transform_postorder( + comp, _transform_switch + )[0] def replace_selections( - bb: building_blocks.ComputationBuildingBlock, + bb: federated_language.framework.ComputationBuildingBlock, ref_name: str, path_to_replacement: dict[ - tuple[int, ...], building_blocks.ComputationBuildingBlock + tuple[int, ...], federated_language.framework.ComputationBuildingBlock ], -) -> building_blocks.ComputationBuildingBlock: +) -> federated_language.framework.ComputationBuildingBlock: """Identifies selection pattern and replaces with new binding. Note that this function is somewhat brittle in that it only replaces AST @@ -336,9 +349,9 @@ def replace_selections( computations, which we error on. Args: - bb: Instance of `building_blocks.ComputationBuildingBlock` in which we wish - to replace the selections from reference `ref_name` with any path in - `paths_to_replacement` with the corresponding building block. + bb: Instance of `federated_language.framework.ComputationBuildingBlock` in + which we wish to replace the selections from reference `ref_name` with any + path in `paths_to_replacement` with the corresponding building block. ref_name: Name of the reference to look for selections from. path_to_replacement: A map from selection path to the building block with which to replace the selection. Note; it is not valid to specify @@ -353,14 +366,14 @@ def _replace(inner_bb): # Start with an empty selection path = [] selection = inner_bb - while isinstance(selection, building_blocks.Selection): + while isinstance(selection, federated_language.framework.Selection): path.append(selection.as_index()) selection = selection.source # In ASTs like x[0][1], we'll see the last (outermost) selection first. path.reverse() path = tuple(path) if ( - isinstance(selection, building_blocks.Reference) + isinstance(selection, federated_language.framework.Reference) and selection.name == ref_name and path in path_to_replacement and path_to_replacement[path].type_signature.is_equivalent_to( @@ -369,10 +382,14 @@ def _replace(inner_bb): ): return path_to_replacement[path], True if ( - isinstance(inner_bb, building_blocks.Call) - and isinstance(inner_bb.function, building_blocks.CompiledComputation) + isinstance(inner_bb, federated_language.framework.Call) + and isinstance( + inner_bb.function, federated_language.framework.CompiledComputation + ) and inner_bb.argument is not None - and isinstance(inner_bb.argument, building_blocks.Reference) + and isinstance( + inner_bb.argument, federated_language.framework.Reference + ) and inner_bb.argument.name == ref_name ): raise ValueError( @@ -388,7 +405,7 @@ def _replace(inner_bb): # protection against triggering multiple replacements for nested selections # (the type signature check above does provide one layer of protection # already). - result, _ = transformation_utils.transform_postorder(bb, _replace) + result, _ = federated_language.framework.transform_postorder(bb, _replace) return result @@ -404,9 +421,9 @@ def __init__(self, path, bb): def as_function_of_some_subparameters( - bb: building_blocks.Lambda, + bb: federated_language.framework.Lambda, paths: Sequence[Sequence[int]], -) -> building_blocks.Lambda: +) -> federated_language.framework.Lambda: """Turns `x -> ...only uses parts of x...` into `parts_of_x -> ...`. The names of locals in blocks are not modified, but unused block locals @@ -416,15 +433,16 @@ def as_function_of_some_subparameters( the returned computation will have unbound references. Args: - bb: Instance of `building_blocks.Lambda` that we wish to rewrite as a - function of some subparameters. + bb: Instance of `federated_language.framework.Lambda` that we wish to + rewrite as a function of some subparameters. paths: List of the paths representing the input subparameters to use. Each path is a tuple of ints (e.g. (5, 3) would represent a selection into the original arg like arg[5][3]). Note; it is not valid to specify overlapping selection paths (where one path encompasses another). Returns: - An instance of `building_blocks.Lambda` with a struct input parameter where + An instance of `federated_language.framework.Lambda` with a struct input + parameter where the ith element in the input parameter corresponds to the ith provided path. Raises: @@ -436,18 +454,18 @@ def _get_block_local_names(comp): names = [] def _visit(comp): - if isinstance(comp, building_blocks.Block): + if isinstance(comp, federated_language.framework.Block): for name, _ in comp.locals: names.append(name) - tree_analysis.visit_postorder(comp, _visit) + federated_language.framework.visit_postorder(comp, _visit) return names - tree_analysis.check_has_unique_names(bb) + federated_language.framework.check_has_unique_names(bb) original_local_names = _get_block_local_names(bb) bb, _ = remove_unused_block_locals(bb) - name_generator = building_block_factory.unique_name_generator(bb) + name_generator = federated_language.framework.unique_name_generator(bb) type_list = [] int_paths = [] @@ -455,7 +473,7 @@ def _visit(comp): selected_type = bb.parameter_type int_path = [] for index in path: - if not isinstance(selected_type, computation_types.StructType): + if not isinstance(selected_type, federated_language.StructType): raise ParameterSelectionError(path, bb) py_typecheck.check_type(index, int) if index >= len(selected_type): @@ -465,12 +483,12 @@ def _visit(comp): int_paths.append(tuple(int_path)) type_list.append(selected_type) - ref_to_struct = building_blocks.Reference( - next(name_generator), computation_types.StructType(type_list) + ref_to_struct = federated_language.framework.Reference( + next(name_generator), federated_language.StructType(type_list) ) path_to_replacement = {} for i, path in enumerate(int_paths): - path_to_replacement[path] = building_blocks.Selection( + path_to_replacement[path] = federated_language.framework.Selection( ref_to_struct, index=i ) @@ -478,9 +496,9 @@ def _visit(comp): bb.result, bb.parameter_name, path_to_replacement ) # Normalize the body so that it is a block. - if not isinstance(new_lambda_body, building_blocks.Block): - new_lambda_body = building_blocks.Block([], new_lambda_body) - lambda_with_zipped_param = building_blocks.Lambda( + if not isinstance(new_lambda_body, federated_language.framework.Block): + new_lambda_body = federated_language.framework.Block([], new_lambda_body) + lambda_with_zipped_param = federated_language.framework.Lambda( ref_to_struct.name, ref_to_struct.type_signature, new_lambda_body ) @@ -496,11 +514,11 @@ def strip_placement(comp): For this function to complete successfully `comp` must: 1) contain at most one federated placement. 2) not contain intrinsics besides `apply`, `map`, `zip`, and `federated_value` - 3) not contain `building_blocks.Data` of federated type. + 3) not contain `federated_language.framework.Data` of federated type. Args: - comp: Instance of `building_blocks.ComputationBuildingBlock` satisfying the - assumptions above. + comp: Instance of `federated_language.framework.ComputationBuildingBlock` + satisfying the assumptions above. Returns: A modified version of `comp` containing no intrinsics nor any federated @@ -510,9 +528,11 @@ def strip_placement(comp): TypeError: If `comp` is not a building block. ValueError: If conditions (1), (2), or (3) above are unsatisfied. """ - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + comp, federated_language.framework.ComputationBuildingBlock + ) placement = None - name_generator = building_block_factory.unique_name_generator(comp) + name_generator = federated_language.framework.unique_name_generator(comp) def _ensure_single_placement(new_placement): nonlocal placement @@ -527,7 +547,7 @@ def _ensure_single_placement(new_placement): ) def _remove_placement_from_type(type_spec): - if isinstance(type_spec, computation_types.FederatedType): + if isinstance(type_spec, federated_language.FederatedType): _ensure_single_placement(type_spec.placement) return type_spec.member, True else: @@ -535,33 +555,37 @@ def _remove_placement_from_type(type_spec): def _remove_reference_placement(comp): """Unwraps placement from references and updates unbound reference info.""" - new_type, _ = type_transformations.transform_type_postorder( + new_type, _ = federated_language.framework.transform_type_postorder( comp.type_signature, _remove_placement_from_type ) - return building_blocks.Reference(comp.name, new_type) + return federated_language.framework.Reference(comp.name, new_type) def _identity_function(arg_type): """Creates `lambda x: x` with argument type `arg_type`.""" arg_name = next(name_generator) - val = building_blocks.Reference(arg_name, arg_type) - lam = building_blocks.Lambda(arg_name, arg_type, val) + val = federated_language.framework.Reference(arg_name, arg_type) + lam = federated_language.framework.Lambda(arg_name, arg_type, val) return lam def _call_first_with_second_function(fn_type, arg_type): """Creates `lambda x: x[0](x[1])` with the provided .""" arg_name = next(name_generator) - tuple_ref = building_blocks.Reference(arg_name, [fn_type, arg_type]) - fn = building_blocks.Selection(tuple_ref, index=0) - arg = building_blocks.Selection(tuple_ref, index=1) - called_fn = building_blocks.Call(fn, arg) - return building_blocks.Lambda(arg_name, tuple_ref.type_signature, called_fn) + tuple_ref = federated_language.framework.Reference( + arg_name, [fn_type, arg_type] + ) + fn = federated_language.framework.Selection(tuple_ref, index=0) + arg = federated_language.framework.Selection(tuple_ref, index=1) + called_fn = federated_language.framework.Call(fn, arg) + return federated_language.framework.Lambda( + arg_name, tuple_ref.type_signature, called_fn + ) def _call_function(arg_type): """Creates `lambda x: x()` argument type `arg_type`.""" arg_name = next(name_generator) - arg_ref = building_blocks.Reference(arg_name, arg_type) - called_arg = building_blocks.Call(arg_ref, None) - return building_blocks.Lambda(arg_name, arg_type, called_arg) + arg_ref = federated_language.framework.Reference(arg_name, arg_type) + called_arg = federated_language.framework.Call(arg_ref, None) + return federated_language.framework.Lambda(arg_name, arg_type, called_arg) def _replace_intrinsics_with_functions(comp): """Helper to remove intrinsics from the AST.""" @@ -570,10 +594,10 @@ def _replace_intrinsics_with_functions(comp): # These functions have no runtime behavior and only exist to adjust # placement. They are replaced here with `lambda x: x`. identities = [ - intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri, - intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri, - intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri, - intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri, + federated_language.framework.FEDERATED_ZIP_AT_SERVER.uri, + federated_language.framework.FEDERATED_ZIP_AT_CLIENTS.uri, + federated_language.framework.FEDERATED_VALUE_AT_SERVER.uri, + federated_language.framework.FEDERATED_VALUE_AT_CLIENTS.uri, ] if comp.uri in identities: return _identity_function(tys.result.member) @@ -581,9 +605,9 @@ def _replace_intrinsics_with_functions(comp): # These functions all `map` a value and are replaced with # `lambda args: args[0](args[1]) maps = [ - intrinsic_defs.FEDERATED_MAP.uri, - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, - intrinsic_defs.FEDERATED_APPLY.uri, + federated_language.framework.FEDERATED_MAP.uri, + federated_language.framework.FEDERATED_MAP_ALL_EQUAL.uri, + federated_language.framework.FEDERATED_APPLY.uri, ] if comp.uri in maps: return _call_first_with_second_function( @@ -593,8 +617,8 @@ def _replace_intrinsics_with_functions(comp): # `federated_eval`'s argument must simply be `call`ed and is replaced # with `lambda x: x()` evals = [ - intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri, - intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS.uri, + federated_language.framework.FEDERATED_EVAL_AT_SERVER.uri, + federated_language.framework.FEDERATED_EVAL_AT_CLIENTS.uri, ] if comp.uri in evals: return _call_function(tys.parameter) @@ -606,64 +630,72 @@ def _remove_lambda_placement(comp): if comp.parameter_name is None: new_parameter_type = None else: - new_parameter_type, _ = type_transformations.transform_type_postorder( - comp.parameter_type, _remove_placement_from_type + new_parameter_type, _ = ( + federated_language.framework.transform_type_postorder( + comp.parameter_type, _remove_placement_from_type + ) ) - return building_blocks.Lambda( + return federated_language.framework.Lambda( comp.parameter_name, new_parameter_type, comp.result ) def _simplify_calls(comp): """Unwraps structures introduced by removing intrinsics.""" zip_or_value_removed = ( - isinstance(comp.function.result, building_blocks.Reference) + isinstance(comp.function.result, federated_language.framework.Reference) and comp.function.result.name == comp.function.parameter_name ) if zip_or_value_removed: return comp.argument else: map_removed = ( - isinstance(comp.function.result, building_blocks.Call) + isinstance(comp.function.result, federated_language.framework.Call) and isinstance( - comp.function.result.function, building_blocks.Selection + comp.function.result.function, + federated_language.framework.Selection, ) and comp.function.result.function.index == 0 and isinstance( - comp.function.result.argument, building_blocks.Selection + comp.function.result.argument, + federated_language.framework.Selection, ) and comp.function.result.argument.index == 1 and isinstance( - comp.function.result.function.source, building_blocks.Reference + comp.function.result.function.source, + federated_language.framework.Reference, ) and comp.function.result.function.source.name == comp.function.parameter_name and isinstance( - comp.function.result.function.source, building_blocks.Reference + comp.function.result.function.source, + federated_language.framework.Reference, ) and comp.function.result.function.source.name == comp.function.parameter_name - and isinstance(comp.argument, building_blocks.Struct) + and isinstance(comp.argument, federated_language.framework.Struct) ) if map_removed: - return building_blocks.Call(comp.argument[0], comp.argument[1]) + return federated_language.framework.Call( + comp.argument[0], comp.argument[1] + ) return comp def _transform(comp): """Dispatches to helpers above.""" - if isinstance(comp, building_blocks.Reference): + if isinstance(comp, federated_language.framework.Reference): return _remove_reference_placement(comp), True - elif isinstance(comp, building_blocks.Intrinsic): + elif isinstance(comp, federated_language.framework.Intrinsic): return _replace_intrinsics_with_functions(comp), True - elif isinstance(comp, building_blocks.Lambda): + elif isinstance(comp, federated_language.framework.Lambda): return _remove_lambda_placement(comp), True - elif isinstance(comp, building_blocks.Call) and isinstance( - comp.function, building_blocks.Lambda + elif isinstance(comp, federated_language.framework.Call) and isinstance( + comp.function, federated_language.framework.Lambda ): return _simplify_calls(comp), True - elif isinstance(comp, building_blocks.Data) and isinstance( - comp.type_signature, computation_types.FederatedType + elif isinstance(comp, federated_language.framework.Data) and isinstance( + comp.type_signature, federated_language.FederatedType ): raise ValueError(f'Cannot strip placement from federated data: {comp}') return comp, False - return transformation_utils.transform_postorder(comp, _transform) + return federated_language.framework.transform_postorder(comp, _transform) diff --git a/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py b/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py index 4d53ec5654..f357efa82f 100644 --- a/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py +++ b/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py @@ -16,20 +16,13 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.common_libs import golden from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.compiler import building_block_test_utils -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.compiler import transformation_utils -from tensorflow_federated.python.core.impl.compiler import tree_analysis from tensorflow_federated.python.core.impl.compiler import tree_transformations -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils class TransformTestBase(absltest.TestCase): @@ -45,7 +38,7 @@ def assert_transforms(self, comp, file, changes_type=False, unmodified=False): ), ) if not changes_type: - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( comp.type_signature, after.type_signature ) if unmodified: @@ -56,9 +49,13 @@ def assert_transforms(self, comp, file, changes_type=False, unmodified=False): def _create_chained_whimsy_federated_maps(functions, arg): - py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + arg, federated_language.framework.ComputationBuildingBlock + ) for fn in functions: - py_typecheck.check_type(fn, building_blocks.ComputationBuildingBlock) + py_typecheck.check_type( + fn, federated_language.framework.ComputationBuildingBlock + ) if not fn.parameter_type.is_assignable_from(arg.type_signature.member): raise TypeError( 'The parameter of the function is of type {}, and the argument is of ' @@ -66,7 +63,7 @@ def _create_chained_whimsy_federated_maps(functions, arg): str(fn.parameter_type), str(arg.type_signature.member) ) ) - call = building_block_factory.create_federated_map_all_equal(fn, arg) + call = federated_language.framework.create_federated_map_all_equal(fn, arg) arg = call return call @@ -80,12 +77,12 @@ def test_raises_type_error(self): @parameterized.named_parameters( ( 'federated_map', - intrinsic_defs.FEDERATED_MAP.uri, + federated_language.framework.FEDERATED_MAP.uri, building_block_test_utils.create_whimsy_called_federated_map, ), ( 'federated_map_all_equal', - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, + federated_language.framework.FEDERATED_MAP_ALL_EQUAL.uri, building_block_test_utils.create_whimsy_called_federated_map_all_equal, ), ) @@ -156,14 +153,14 @@ def test_removes_sequence_map(self): def test_removes_federated_map_with_named_result(self): parameter_type = [('a', np.int32), ('b', np.int32)] fn = building_block_test_utils.create_identity_function('c', parameter_type) - arg_type = computation_types.FederatedType( - parameter_type, placements.CLIENTS + arg_type = federated_language.FederatedType( + parameter_type, federated_language.CLIENTS ) any_proto = building_block_test_utils.create_any_proto_from_array( np.array(1, np.int32) ) - arg = building_blocks.Data(any_proto, arg_type) - call = building_block_factory.create_federated_map(fn, arg) + arg = federated_language.framework.Data(any_proto, arg_type) + call = federated_language.framework.create_federated_map(fn, arg) comp = call transformed_comp, modified = ( @@ -205,9 +202,11 @@ def test_removes_nested_federated_map(self): def test_removes_chained_federated_maps(self): fn = building_block_test_utils.create_identity_function('a', np.int32) - arg = building_block_factory.create_federated_value( - building_blocks.Literal(1, computation_types.TensorType(np.int32)), - placement=placements.CLIENTS, + arg = federated_language.framework.create_federated_value( + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), + placement=federated_language.CLIENTS, ) call = _create_chained_whimsy_federated_maps([fn, fn], arg) comp = call @@ -246,8 +245,10 @@ def test_does_not_remove_whimsy_intrinsic(self): def test_does_not_remove_called_lambda(self): fn = building_block_test_utils.create_identity_function('a', np.int32) - arg = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - call = building_blocks.Call(fn, arg) + arg = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + call = federated_language.framework.Call(fn, arg) comp = call transformed_comp, modified = ( @@ -269,31 +270,39 @@ def setUp(self): self._unused_block_remover = tree_transformations.RemoveUnusedBlockLocals() def test_should_transform_block(self): - blk = building_blocks.Block( + blk = federated_language.framework.Block( [( 'x', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), )], - building_blocks.Literal(2, computation_types.TensorType(np.int32)), + federated_language.framework.Literal( + 2, federated_language.TensorType(np.int32) + ), ) self.assertTrue(self._unused_block_remover.should_transform(blk)) def test_should_not_transform_data(self): - data = building_blocks.Literal(2, computation_types.TensorType(np.int32)) + data = federated_language.framework.Literal( + 2, federated_language.TensorType(np.int32) + ) self.assertFalse(self._unused_block_remover.should_transform(data)) def test_removes_block_with_unused_reference(self): - input_data = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) + input_data = federated_language.framework.Literal( + 2, federated_language.TensorType(np.int32) ) - blk = building_blocks.Block( + blk = federated_language.framework.Block( [( 'x', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), )], input_data, ) - data, modified = transformation_utils.transform_postorder( + data, modified = federated_language.framework.transform_postorder( blk, self._unused_block_remover.transform ) self.assertTrue(modified) @@ -302,11 +311,11 @@ def test_removes_block_with_unused_reference(self): ) def test_unwraps_block_with_empty_locals(self): - input_data = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + input_data = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) ) - blk = building_blocks.Block([], input_data) - data, modified = transformation_utils.transform_postorder( + blk = federated_language.framework.Block([], input_data) + data, modified = federated_language.framework.transform_postorder( blk, self._unused_block_remover.transform ) self.assertTrue(modified) @@ -315,18 +324,22 @@ def test_unwraps_block_with_empty_locals(self): ) def test_removes_nested_blocks_with_unused_reference(self): - input_data = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) + input_data = federated_language.framework.Literal( + 2, federated_language.TensorType(np.int32) ) - blk = building_blocks.Block( + blk = federated_language.framework.Block( [( 'x', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), )], input_data, ) - higher_level_blk = building_blocks.Block([('y', input_data)], blk) - data, modified = transformation_utils.transform_postorder( + higher_level_blk = federated_language.framework.Block( + [('y', input_data)], blk + ) + data, modified = federated_language.framework.transform_postorder( higher_level_blk, self._unused_block_remover.transform ) self.assertTrue(modified) @@ -335,15 +348,19 @@ def test_removes_nested_blocks_with_unused_reference(self): ) def test_leaves_single_used_reference(self): - blk = building_blocks.Block( + blk = federated_language.framework.Block( [( 'x', - building_blocks.Literal(1, computation_types.TensorType(np.int32)), + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), )], - building_blocks.Reference('x', np.int32), + federated_language.framework.Reference('x', np.int32), ) - transformed_blk, modified = transformation_utils.transform_postorder( - blk, self._unused_block_remover.transform + transformed_blk, modified = ( + federated_language.framework.transform_postorder( + blk, self._unused_block_remover.transform + ) ) self.assertFalse(modified) self.assertEqual( @@ -351,20 +368,22 @@ def test_leaves_single_used_reference(self): ) def test_leaves_chained_used_references(self): - blk = building_blocks.Block( + blk = federated_language.framework.Block( [ ( 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) ), ), - ('y', building_blocks.Reference('x', np.int32)), + ('y', federated_language.framework.Reference('x', np.int32)), ], - building_blocks.Reference('y', np.int32), + federated_language.framework.Reference('y', np.int32), ) - transformed_blk, modified = transformation_utils.transform_postorder( - blk, self._unused_block_remover.transform + transformed_blk, modified = ( + federated_language.framework.transform_postorder( + blk, self._unused_block_remover.transform + ) ) self.assertFalse(modified) self.assertEqual( @@ -374,23 +393,25 @@ def test_leaves_chained_used_references(self): def test_removes_locals_referencing_each_other_but_unreferenced_in_result( self, ): - input_data = building_blocks.Literal( - 2, computation_types.TensorType(np.int32) + input_data = federated_language.framework.Literal( + 2, federated_language.TensorType(np.int32) ) - blk = building_blocks.Block( + blk = federated_language.framework.Block( [ ( 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) ), ), - ('y', building_blocks.Reference('x', np.int32)), + ('y', federated_language.framework.Reference('x', np.int32)), ], input_data, ) - transformed_blk, modified = transformation_utils.transform_postorder( - blk, self._unused_block_remover.transform + transformed_blk, modified = ( + federated_language.framework.transform_postorder( + blk, self._unused_block_remover.transform + ) ) self.assertTrue(modified) self.assertEqual( @@ -399,26 +420,28 @@ def test_removes_locals_referencing_each_other_but_unreferenced_in_result( ) def test_leaves_lone_referenced_local(self): - ref = building_blocks.Reference('y', np.int32) - blk = building_blocks.Block( + ref = federated_language.framework.Reference('y', np.int32) + blk = federated_language.framework.Block( [ ( 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) ), ), ( 'y', - building_blocks.Literal( - 2, computation_types.TensorType(np.int32) + federated_language.framework.Literal( + 2, federated_language.TensorType(np.int32) ), ), ], ref, ) - transformed_blk, modified = transformation_utils.transform_postorder( - blk, self._unused_block_remover.transform + transformed_blk, modified = ( + federated_language.framework.transform_postorder( + blk, self._unused_block_remover.transform + ) ) self.assertTrue(modified) self.assertEqual(transformed_blk.compact_representation(), '(let y=2 in y)') @@ -436,10 +459,10 @@ def test_raises_type_error(self): def test_renames_lambda_but_not_unbound_reference_when_given_name_generator( self, ): - ref = building_blocks.Reference('x', np.int32) - lambda_binding_y = building_blocks.Lambda('y', np.float32, ref) + ref = federated_language.framework.Reference('x', np.int32) + lambda_binding_y = federated_language.framework.Lambda('y', np.float32, ref) - name_generator = building_block_factory.unique_name_generator( + name_generator = federated_language.framework.unique_name_generator( lambda_binding_y ) transformed_comp, modified = tree_transformations.uniquify_reference_names( @@ -454,9 +477,13 @@ def test_renames_lambda_but_not_unbound_reference_when_given_name_generator( self.assertTrue(modified) def test_single_level_block(self): - ref = building_blocks.Reference('a', np.int32) - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block = building_blocks.Block((('a', lit), ('a', ref), ('a', ref)), ref) + ref = federated_language.framework.Reference('a', np.int32) + lit = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + block = federated_language.framework.Block( + (('a', lit), ('a', ref), ('a', ref)), ref + ) transformed_comp, modified = tree_transformations.uniquify_reference_names( block @@ -467,14 +494,20 @@ def test_single_level_block(self): transformed_comp.compact_representation(), '(let a=1,_var1=a,_var2=_var1 in _var2)', ) - tree_analysis.check_has_unique_names(transformed_comp) + federated_language.framework.check_has_unique_names(transformed_comp) self.assertTrue(modified) def test_nested_blocks(self): - x_ref = building_blocks.Reference('a', np.int32) - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block1 = building_blocks.Block([('a', lit), ('a', x_ref)], x_ref) - block2 = building_blocks.Block([('a', lit), ('a', x_ref)], block1) + x_ref = federated_language.framework.Reference('a', np.int32) + lit = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + block1 = federated_language.framework.Block( + [('a', lit), ('a', x_ref)], x_ref + ) + block2 = federated_language.framework.Block( + [('a', lit), ('a', x_ref)], block1 + ) transformed_comp, modified = tree_transformations.uniquify_reference_names( block2 @@ -488,18 +521,23 @@ def test_nested_blocks(self): transformed_comp.compact_representation(), '(let a=1,_var1=a in (let _var2=1,_var3=_var2 in _var3))', ) - tree_analysis.check_has_unique_names(transformed_comp) + federated_language.framework.check_has_unique_names(transformed_comp) self.assertTrue(modified) def test_nested_lambdas(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - input1 = building_blocks.Reference('a', lit.type_signature) - first_level_call = building_blocks.Call( - building_blocks.Lambda('a', input1.type_signature, input1), lit - ) - input2 = building_blocks.Reference('b', first_level_call.type_signature) - second_level_call = building_blocks.Call( - building_blocks.Lambda('b', input2.type_signature, input2), + lit = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + input1 = federated_language.framework.Reference('a', lit.type_signature) + first_level_call = federated_language.framework.Call( + federated_language.framework.Lambda('a', input1.type_signature, input1), + lit, + ) + input2 = federated_language.framework.Reference( + 'b', first_level_call.type_signature + ) + second_level_call = federated_language.framework.Call( + federated_language.framework.Lambda('b', input2.type_signature, input2), first_level_call, ) @@ -510,20 +548,26 @@ def test_nested_lambdas(self): self.assertEqual( transformed_comp.compact_representation(), '(b -> b)((a -> a)(1))' ) - tree_analysis.check_has_unique_names(transformed_comp) + federated_language.framework.check_has_unique_names(transformed_comp) self.assertFalse(modified) def test_block_lambda_block_lambda(self): - x_ref = building_blocks.Reference('a', np.int32) - inner_lambda = building_blocks.Lambda('a', np.int32, x_ref) - called_lambda = building_blocks.Call(inner_lambda, x_ref) - lower_block = building_blocks.Block( + x_ref = federated_language.framework.Reference('a', np.int32) + inner_lambda = federated_language.framework.Lambda('a', np.int32, x_ref) + called_lambda = federated_language.framework.Call(inner_lambda, x_ref) + lower_block = federated_language.framework.Block( [('a', x_ref), ('a', x_ref)], called_lambda ) - second_lambda = building_blocks.Lambda('a', np.int32, lower_block) - second_call = building_blocks.Call(second_lambda, x_ref) - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - last_block = building_blocks.Block([('a', lit), ('a', x_ref)], second_call) + second_lambda = federated_language.framework.Lambda( + 'a', np.int32, lower_block + ) + second_call = federated_language.framework.Call(second_lambda, x_ref) + lit = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + last_block = federated_language.framework.Block( + [('a', lit), ('a', x_ref)], second_call + ) transformed_comp, modified = tree_transformations.uniquify_reference_names( last_block @@ -540,23 +584,29 @@ def test_block_lambda_block_lambda(self): ' (_var5 -> _var5)(_var4)))(_var1))' ), ) - tree_analysis.check_has_unique_names(transformed_comp) + federated_language.framework.check_has_unique_names(transformed_comp) self.assertTrue(modified) def test_blocks_nested_inside_of_locals(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - lower_block = building_blocks.Block([('a', lit)], lit) - middle_block = building_blocks.Block([('a', lower_block)], lit) - higher_block = building_blocks.Block([('a', middle_block)], lit) - y_ref = building_blocks.Reference('a', np.int32) - lower_block_with_y_ref = building_blocks.Block([('a', y_ref)], lit) - middle_block_with_y_ref = building_blocks.Block( + lit = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + lower_block = federated_language.framework.Block([('a', lit)], lit) + middle_block = federated_language.framework.Block([('a', lower_block)], lit) + higher_block = federated_language.framework.Block( + [('a', middle_block)], lit + ) + y_ref = federated_language.framework.Reference('a', np.int32) + lower_block_with_y_ref = federated_language.framework.Block( + [('a', y_ref)], lit + ) + middle_block_with_y_ref = federated_language.framework.Block( [('a', lower_block_with_y_ref)], lit ) - higher_block_with_y_ref = building_blocks.Block( + higher_block_with_y_ref = federated_language.framework.Block( [('a', middle_block_with_y_ref)], lit ) - multiple_bindings_highest_block = building_blocks.Block( + multiple_bindings_highest_block = federated_language.framework.Block( [('a', higher_block), ('a', higher_block_with_y_ref)], higher_block_with_y_ref, ) @@ -565,11 +615,13 @@ def test_blocks_nested_inside_of_locals(self): multiple_bindings_highest_block, 'uniquify_names_blocks_nested_inside_of_locals.expected', ) - tree_analysis.check_has_unique_names(transformed_comp) + federated_language.framework.check_has_unique_names(transformed_comp) def test_keeps_existing_nonoverlapping_names(self): - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - block = building_blocks.Block([('a', lit), ('b', lit)], lit) + lit = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + block = federated_language.framework.Block([('a', lit), ('b', lit)], lit) comp = block transformed_comp, modified = tree_transformations.uniquify_reference_names( @@ -590,68 +642,76 @@ def test_raises_on_none(self): tree_transformations.normalize_types(None) def test_ignore_unnormalized_all_equal(self): - fed_type_all_equal = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + fed_type_all_equal = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) unnormalized_comp = tree_transformations.normalize_types( - building_blocks.Reference('x', fed_type_all_equal), + federated_language.framework.Reference('x', fed_type_all_equal), normalize_all_equal_bit=False, ) self.assertEqual(unnormalized_comp.type_signature, fed_type_all_equal) - self.assertIsInstance(unnormalized_comp, building_blocks.Reference) + self.assertIsInstance( + unnormalized_comp, federated_language.framework.Reference + ) self.assertEqual(str(unnormalized_comp), 'x') def test_converts_all_equal_at_clients_reference_to_not_equal(self): - fed_type_all_equal = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + fed_type_all_equal = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) normalized_comp = tree_transformations.normalize_types( - building_blocks.Reference('x', fed_type_all_equal) + federated_language.framework.Reference('x', fed_type_all_equal) ) self.assertEqual( normalized_comp.type_signature, - computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False + federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=False ), ) - self.assertIsInstance(normalized_comp, building_blocks.Reference) + self.assertIsInstance( + normalized_comp, federated_language.framework.Reference + ) self.assertEqual(str(normalized_comp), 'x') def test_converts_not_all_equal_at_server_reference_to_equal(self): - fed_type_not_all_equal = computation_types.FederatedType( - np.int32, placements.SERVER, all_equal=False + fed_type_not_all_equal = federated_language.FederatedType( + np.int32, federated_language.SERVER, all_equal=False ) normalized_comp = tree_transformations.normalize_types( - building_blocks.Reference('x', fed_type_not_all_equal) + federated_language.framework.Reference('x', fed_type_not_all_equal) ) self.assertEqual( normalized_comp.type_signature, - computation_types.FederatedType( - np.int32, placements.SERVER, all_equal=True + federated_language.FederatedType( + np.int32, federated_language.SERVER, all_equal=True ), ) - self.assertIsInstance(normalized_comp, building_blocks.Reference) + self.assertIsInstance( + normalized_comp, federated_language.framework.Reference + ) self.assertEqual(str(normalized_comp), 'x') def test_converts_all_equal_at_clients_lambda_parameter_to_not_equal(self): - fed_type_all_equal = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + fed_type_all_equal = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) - normalized_fed_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + normalized_fed_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - ref = building_blocks.Reference('x', fed_type_all_equal) - lam = building_blocks.Lambda('x', fed_type_all_equal, ref) + ref = federated_language.framework.Reference('x', fed_type_all_equal) + lam = federated_language.framework.Lambda('x', fed_type_all_equal, ref) normalized_lambda = tree_transformations.normalize_types(lam) self.assertEqual( lam.type_signature, - computation_types.FunctionType(fed_type_all_equal, fed_type_all_equal), + federated_language.FunctionType(fed_type_all_equal, fed_type_all_equal), + ) + self.assertIsInstance( + normalized_lambda, federated_language.framework.Lambda ) - self.assertIsInstance(normalized_lambda, building_blocks.Lambda) self.assertEqual(str(normalized_lambda), '(x -> x)') self.assertEqual( normalized_lambda.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( normalized_fed_type, normalized_fed_type ), ) @@ -659,18 +719,18 @@ def test_converts_all_equal_at_clients_lambda_parameter_to_not_equal(self): def test_converts_all_equal_at_clients_lambda_struct_parameter_to_not_equal( self, ): - fed_type_all_equal = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + fed_type_all_equal = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) - normalized_fed_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + normalized_fed_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', - computation_types.StructType([fed_type_all_equal, fed_type_all_equal]), - building_blocks.Reference( + federated_language.StructType([fed_type_all_equal, fed_type_all_equal]), + federated_language.framework.Reference( 'x', - computation_types.StructType( + federated_language.StructType( [fed_type_all_equal, fed_type_all_equal] ), ), @@ -678,24 +738,26 @@ def test_converts_all_equal_at_clients_lambda_struct_parameter_to_not_equal( normalized_lambda = tree_transformations.normalize_types(lam) self.assertEqual( lam.type_signature, - computation_types.FunctionType( - computation_types.StructType( + federated_language.FunctionType( + federated_language.StructType( [fed_type_all_equal, fed_type_all_equal] ), - computation_types.StructType( + federated_language.StructType( [fed_type_all_equal, fed_type_all_equal] ), ), ) - self.assertIsInstance(normalized_lambda, building_blocks.Lambda) + self.assertIsInstance( + normalized_lambda, federated_language.framework.Lambda + ) self.assertEqual(str(normalized_lambda), '(x -> x)') self.assertEqual( normalized_lambda.type_signature, - computation_types.FunctionType( - computation_types.StructType( + federated_language.FunctionType( + federated_language.StructType( [normalized_fed_type, normalized_fed_type] ), - computation_types.StructType( + federated_language.StructType( [normalized_fed_type, normalized_fed_type] ), ), @@ -704,93 +766,99 @@ def test_converts_all_equal_at_clients_lambda_struct_parameter_to_not_equal( def test_converts_all_equal_at_clients_lambda_nested_struct_parameter_to_not_equal( self, ): - fed_type_all_equal = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + fed_type_all_equal = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) - normalized_fed_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + normalized_fed_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( 'x', - computation_types.StructType([ + federated_language.StructType([ fed_type_all_equal, - computation_types.StructType([fed_type_all_equal]), + federated_language.StructType([fed_type_all_equal]), ]), - building_blocks.Reference( + federated_language.framework.Reference( 'x', - computation_types.StructType([ + federated_language.StructType([ fed_type_all_equal, - computation_types.StructType([fed_type_all_equal]), + federated_language.StructType([fed_type_all_equal]), ]), ), ) normalized_lambda = tree_transformations.normalize_types(lam) self.assertEqual( lam.type_signature, - computation_types.FunctionType( - computation_types.StructType([ + federated_language.FunctionType( + federated_language.StructType([ fed_type_all_equal, - computation_types.StructType([fed_type_all_equal]), + federated_language.StructType([fed_type_all_equal]), ]), - computation_types.StructType([ + federated_language.StructType([ fed_type_all_equal, - computation_types.StructType([fed_type_all_equal]), + federated_language.StructType([fed_type_all_equal]), ]), ), ) - self.assertIsInstance(normalized_lambda, building_blocks.Lambda) + self.assertIsInstance( + normalized_lambda, federated_language.framework.Lambda + ) self.assertEqual(str(normalized_lambda), '(x -> x)') self.assertEqual( normalized_lambda.type_signature, - computation_types.FunctionType( - computation_types.StructType([ + federated_language.FunctionType( + federated_language.StructType([ normalized_fed_type, - computation_types.StructType([normalized_fed_type]), + federated_language.StructType([normalized_fed_type]), ]), - computation_types.StructType([ + federated_language.StructType([ normalized_fed_type, - computation_types.StructType([normalized_fed_type]), + federated_language.StructType([normalized_fed_type]), ]), ), ) def test_converts_not_all_equal_at_server_lambda_parameter_to_equal(self): - fed_type_not_all_equal = computation_types.FederatedType( - np.int32, placements.SERVER, all_equal=False + fed_type_not_all_equal = federated_language.FederatedType( + np.int32, federated_language.SERVER, all_equal=False ) - normalized_fed_type = computation_types.FederatedType( - np.int32, placements.SERVER + normalized_fed_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - ref = building_blocks.Reference('x', fed_type_not_all_equal) - lam = building_blocks.Lambda('x', fed_type_not_all_equal, ref) + ref = federated_language.framework.Reference('x', fed_type_not_all_equal) + lam = federated_language.framework.Lambda('x', fed_type_not_all_equal, ref) normalized_lambda = tree_transformations.normalize_types(lam) self.assertEqual( lam.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( fed_type_not_all_equal, fed_type_not_all_equal ), ) - self.assertIsInstance(normalized_lambda, building_blocks.Lambda) + self.assertIsInstance( + normalized_lambda, federated_language.framework.Lambda + ) self.assertEqual(str(normalized_lambda), '(x -> x)') self.assertEqual( normalized_lambda.type_signature, - computation_types.FunctionType( + federated_language.FunctionType( normalized_fed_type, normalized_fed_type ), ) def test_converts_federated_map_all_equal_to_federated_map(self): - fed_type_all_equal = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + fed_type_all_equal = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) - normalized_fed_type = computation_types.FederatedType( - np.int32, placements.CLIENTS + normalized_fed_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + int_ref = federated_language.framework.Reference('x', np.int32) + int_identity = federated_language.framework.Lambda('x', np.int32, int_ref) + federated_int_ref = federated_language.framework.Reference( + 'y', fed_type_all_equal ) - int_ref = building_blocks.Reference('x', np.int32) - int_identity = building_blocks.Lambda('x', np.int32, int_ref) - federated_int_ref = building_blocks.Reference('y', fed_type_all_equal) called_federated_map_all_equal = ( - building_block_factory.create_federated_map_all_equal( + federated_language.framework.create_federated_map_all_equal( int_identity, federated_int_ref ) ) @@ -799,14 +867,18 @@ def test_converts_federated_map_all_equal_to_federated_map(self): ) self.assertEqual( called_federated_map_all_equal.function.uri, - intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri, + federated_language.framework.FEDERATED_MAP_ALL_EQUAL.uri, + ) + self.assertIsInstance( + normalized_federated_map, federated_language.framework.Call ) - self.assertIsInstance(normalized_federated_map, building_blocks.Call) self.assertIsInstance( - normalized_federated_map.function, building_blocks.Intrinsic + normalized_federated_map.function, + federated_language.framework.Intrinsic, ) self.assertEqual( - normalized_federated_map.function.uri, intrinsic_defs.FEDERATED_MAP.uri + normalized_federated_map.function.uri, + federated_language.framework.FEDERATED_MAP.uri, ) self.assertEqual( normalized_federated_map.type_signature, normalized_fed_type @@ -816,10 +888,11 @@ def test_converts_federated_map_all_equal_to_federated_map(self): class ReplaceSelectionsTest(absltest.TestCase): def test_replace_selection(self): - comp = building_blocks.Selection( - building_blocks.Reference('x', [np.int32, np.int32]), index=1 + comp = federated_language.framework.Selection( + federated_language.framework.Reference('x', [np.int32, np.int32]), + index=1, ) - y = building_blocks.Reference('y', np.int32) + y = federated_language.framework.Reference('y', np.int32) path_to_replacement = { (1,): y, } @@ -829,19 +902,22 @@ def test_replace_selection(self): self.assertEqual(new_comp.proto, y.proto) def test_replace_multiple_instances_of_selection(self): - comp = building_blocks.Struct([ - building_blocks.Selection( - building_blocks.Reference('x', [np.int32, [np.int32]]), index=1 + comp = federated_language.framework.Struct([ + federated_language.framework.Selection( + federated_language.framework.Reference('x', [np.int32, [np.int32]]), + index=1, ), - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', [np.int32, [np.int32]]), + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference( + 'x', [np.int32, [np.int32]] + ), index=1, ), index=0, ), ]) - y = building_blocks.Reference('y', [np.int32]) + y = federated_language.framework.Reference('y', [np.int32]) path_to_replacement = { (1,): y, } @@ -850,16 +926,17 @@ def test_replace_multiple_instances_of_selection(self): ) self.assertEqual( new_comp.proto, - building_blocks.Struct( - [y, building_blocks.Selection(y, index=0)] + federated_language.framework.Struct( + [y, federated_language.framework.Selection(y, index=0)] ).proto, ) def test_replace_selection_mismatching_ref_name(self): - comp = building_blocks.Selection( - building_blocks.Reference('x', [np.int32, np.int32]), index=1 + comp = federated_language.framework.Selection( + federated_language.framework.Reference('x', [np.int32, np.int32]), + index=1, ) - y = building_blocks.Reference('y', np.int32) + y = federated_language.framework.Reference('y', np.int32) path_to_replacement = { (1,): y, } @@ -869,19 +946,23 @@ def test_replace_selection_mismatching_ref_name(self): self.assertEqual(new_comp.proto, comp.proto) def test_fail_replace_compiled_comp(self): - arg_type = computation_types.StructType([np.int32]) - fn_type = computation_types.FunctionType(arg_type, arg_type) + arg_type = federated_language.StructType([np.int32]) + fn_type = federated_language.FunctionType(arg_type, arg_type) mock_fn = mock.create_autospec( - building_blocks.CompiledComputation, spec_set=True, instance=True + federated_language.framework.CompiledComputation, + spec_set=True, + instance=True, ) type(mock_fn).type_signature = mock.PropertyMock( - spec=computation_types.FunctionType, return_value=fn_type, spec_set=True + spec=federated_language.FunctionType, + return_value=fn_type, + spec_set=True, ) - comp = building_blocks.Call( + comp = federated_language.framework.Call( mock_fn, - building_blocks.Reference('x', arg_type), + federated_language.framework.Reference('x', arg_type), ) - y = building_blocks.Reference('y', np.int32) + y = federated_language.framework.Reference('y', np.int32) path_to_replacement = { (0,): y, } @@ -889,13 +970,13 @@ def test_fail_replace_compiled_comp(self): tree_transformations.replace_selections(comp, 'x', path_to_replacement) def test_no_subsequent_replacement(self): - comp = building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', [[np.int32]]), index=0 + comp = federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('x', [[np.int32]]), index=0 ), index=0, ) - replacement = building_blocks.Reference('x', [np.int32]) + replacement = federated_language.framework.Reference('x', [np.int32]) path_to_replacement = { (0,): replacement, } @@ -907,8 +988,8 @@ def test_no_subsequent_replacement(self): # type signatures would not be accurate. self.assertEqual( new_comp.proto, - building_blocks.Selection( - building_blocks.Reference('x', [np.int32]), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference('x', [np.int32]), index=0 ).proto, ) @@ -916,39 +997,43 @@ def test_no_subsequent_replacement(self): class AsFunctionOfSomeParametersTest(absltest.TestCase): def test_empty_path(self): - comp = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + comp = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) ) new_comp = tree_transformations.as_function_of_some_subparameters(comp, []) - self.assertEqual(new_comp.parameter_type, computation_types.StructType([])) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] + self.assertEqual(new_comp.parameter_type, federated_language.StructType([])) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] + ) self.assertEqual(unbound_references, set(['x'])) def test_all_path(self): - comp = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) + comp = federated_language.framework.Lambda( + 'x', np.int32, federated_language.framework.Reference('x', np.int32) ) new_comp = tree_transformations.as_function_of_some_subparameters( comp, [()] ) self.assertEqual( - new_comp.parameter_type, computation_types.StructType([np.int32]) + new_comp.parameter_type, federated_language.StructType([np.int32]) + ) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] self.assertEmpty(unbound_references) def test_selection_path(self): arg_type = [[np.int32]] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'x', arg_type, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=0 + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=0 ), index=0, ), @@ -957,41 +1042,45 @@ def test_selection_path(self): comp, [(0, 0)] ) self.assertEqual( - new_comp.parameter_type, computation_types.StructType([np.int32]) + new_comp.parameter_type, federated_language.StructType([np.int32]) + ) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] self.assertEmpty(unbound_references) def test_partial_selection_path(self): arg_type = [[np.int32]] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'x', arg_type, - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=0 + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=0 ), ) new_comp = tree_transformations.as_function_of_some_subparameters( comp, [(0,)] ) self.assertEqual( - new_comp.parameter_type, computation_types.StructType([[np.int32]]) + new_comp.parameter_type, federated_language.StructType([[np.int32]]) + ) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] self.assertEmpty(unbound_references) def test_invalid_selection_path(self): arg_type = [[np.int32]] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'x', arg_type, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=0 + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=0 ), index=0, ), @@ -1001,21 +1090,22 @@ def test_invalid_selection_path(self): def test_multiple_selection_path(self): arg_type = [np.int32, np.float32, [np.int32, np.str_]] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'x', arg_type, - building_blocks.Struct([ - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=1 + federated_language.framework.Struct([ + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=1 ), - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=2 + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), + index=2, ), index=0, ), - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=2 + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=2 ), ]), ) @@ -1024,20 +1114,22 @@ def test_multiple_selection_path(self): ) self.assertEqual( new_comp.parameter_type, - computation_types.StructType([np.float32, [np.int32, np.str_]]), + federated_language.StructType([np.float32, [np.int32, np.str_]]), + ) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] self.assertEmpty(unbound_references) def test_unused_selection_path(self): arg_type = [np.int32, np.float32, [np.int32, np.str_]] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'x', arg_type, - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=1 + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=1 ), ) new_comp = tree_transformations.as_function_of_some_subparameters( @@ -1045,21 +1137,23 @@ def test_unused_selection_path(self): ) self.assertEqual( new_comp.parameter_type, - computation_types.StructType([np.float32, [np.int32, np.str_]]), + federated_language.StructType([np.float32, [np.int32, np.str_]]), + ) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] self.assertEmpty(unbound_references) def test_paths_not_applied_sequentially(self): arg_type = [np.int32, np.float32, [np.int32, np.str_]] - comp = building_blocks.Lambda( + comp = federated_language.framework.Lambda( 'x', arg_type, - building_blocks.Selection( - building_blocks.Selection( - building_blocks.Reference('x', arg_type), index=2 + federated_language.framework.Selection( + federated_language.framework.Selection( + federated_language.framework.Reference('x', arg_type), index=2 ), index=1, ), @@ -1070,15 +1164,19 @@ def test_paths_not_applied_sequentially(self): ) self.assertEqual( new_comp.parameter_type, - computation_types.StructType([[np.int32, np.str_], np.float32]), + federated_language.StructType([[np.int32, np.str_], np.float32]), + ) + unbound_references = ( + federated_language.framework.get_map_of_unbound_references(new_comp)[ + new_comp + ] ) - unbound_references = transformation_utils.get_map_of_unbound_references( - new_comp - )[new_comp] self.assertEmpty(unbound_references) - self.assertIsInstance(new_comp.result.result, building_blocks.Selection) self.assertIsInstance( - new_comp.result.result.source, building_blocks.Selection + new_comp.result.result, federated_language.framework.Selection + ) + self.assertIsInstance( + new_comp.result.result.source, federated_language.framework.Selection ) @@ -1086,100 +1184,110 @@ class StripPlacementTest(parameterized.TestCase): def assert_has_no_intrinsics_nor_federated_types(self, comp): def _check(x): - if isinstance(x.type_signature, computation_types.FederatedType): + if isinstance(x.type_signature, federated_language.FederatedType): raise AssertionError(f'Unexpected federated type: {x.type_signature}') - if isinstance(x, building_blocks.Intrinsic): + if isinstance(x, federated_language.framework.Intrinsic): raise AssertionError(f'Unexpected intrinsic: {x}') - tree_analysis.visit_postorder(comp, _check) + federated_language.framework.visit_postorder(comp, _check) def test_raises_on_none(self): with self.assertRaises(TypeError): tree_transformations.strip_placement(None) def test_computation_non_federated_type(self): - before = building_blocks.Literal(1, computation_types.TensorType(np.int32)) + before = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) after, modified = tree_transformations.strip_placement(before) self.assertEqual(before, after) self.assertFalse(modified) def test_raises_disallowed_intrinsic(self): - fed_ref = building_blocks.Reference( - 'x', computation_types.FederatedType(np.int32, placements.SERVER) + fed_ref = federated_language.framework.Reference( + 'x', + federated_language.FederatedType(np.int32, federated_language.SERVER), ) - broadcaster = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_BROADCAST.uri, - computation_types.FunctionType( + broadcaster = federated_language.framework.Intrinsic( + federated_language.framework.FEDERATED_BROADCAST.uri, + federated_language.FunctionType( fed_ref.type_signature, - computation_types.FederatedType( + federated_language.FederatedType( fed_ref.type_signature.member, - placements.CLIENTS, + federated_language.CLIENTS, all_equal=True, ), ), ) - called_broadcast = building_blocks.Call(broadcaster, fed_ref) + called_broadcast = federated_language.framework.Call(broadcaster, fed_ref) with self.assertRaises(ValueError): tree_transformations.strip_placement(called_broadcast) def test_raises_multiple_placements(self): - server_placed_data = building_blocks.Reference( - 'x', computation_types.FederatedType(np.int32, placements.SERVER) + server_placed_data = federated_language.framework.Reference( + 'x', + federated_language.FederatedType(np.int32, federated_language.SERVER), ) - clients_placed_data = building_blocks.Reference( - 'y', computation_types.FederatedType(np.int32, placements.CLIENTS) + clients_placed_data = federated_language.framework.Reference( + 'y', + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) - block_holding_both = building_blocks.Block( + block_holding_both = federated_language.framework.Block( [('x', server_placed_data)], clients_placed_data ) with self.assertRaisesRegex(ValueError, 'multiple different placements'): tree_transformations.strip_placement(block_holding_both) def test_passes_unbound_type_signature_obscured_under_block(self): - fed_ref = building_blocks.Reference( - 'x', computation_types.FederatedType(np.int32, placements.SERVER) + fed_ref = federated_language.framework.Reference( + 'x', + federated_language.FederatedType(np.int32, federated_language.SERVER), ) - block = building_blocks.Block( + block = federated_language.framework.Block( [ ('y', fed_ref), ( 'x', - building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) ), ), - ('z', building_blocks.Reference('x', np.int32)), + ('z', federated_language.framework.Reference('x', np.int32)), ], - building_blocks.Reference('y', fed_ref.type_signature), + federated_language.framework.Reference('y', fed_ref.type_signature), ) tree_transformations.strip_placement(block) def test_passes_noarg_lambda(self): - lam = building_blocks.Lambda( + lam = federated_language.framework.Lambda( None, None, - building_blocks.Literal(1, computation_types.TensorType(np.int32)), + federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ), + ) + fed_int_type = federated_language.FederatedType( + np.int32, federated_language.SERVER ) - fed_int_type = computation_types.FederatedType(np.int32, placements.SERVER) - fed_eval = building_blocks.Intrinsic( - intrinsic_defs.FEDERATED_EVAL_AT_SERVER.uri, - computation_types.FunctionType(lam.type_signature, fed_int_type), + fed_eval = federated_language.framework.Intrinsic( + federated_language.framework.FEDERATED_EVAL_AT_SERVER.uri, + federated_language.FunctionType(lam.type_signature, fed_int_type), ) - called_eval = building_blocks.Call(fed_eval, lam) + called_eval = federated_language.framework.Call(fed_eval, lam) tree_transformations.strip_placement(called_eval) def test_removes_federated_types_under_function(self): int_type = np.int32 - server_int_type = computation_types.FederatedType( - int_type, placements.SERVER + server_int_type = federated_language.FederatedType( + int_type, federated_language.SERVER ) - int_ref = building_blocks.Reference('x', int_type) - int_id = building_blocks.Lambda('x', int_type, int_ref) - fed_ref = building_blocks.Reference('x', server_int_type) - applied_id = building_block_factory.create_federated_map_or_apply( + int_ref = federated_language.framework.Reference('x', int_type) + int_id = federated_language.framework.Lambda('x', int_type, int_ref) + fed_ref = federated_language.framework.Reference('x', server_int_type) + applied_id = federated_language.framework.create_federated_map_or_apply( int_id, fed_ref ) - before = building_block_factory.create_federated_map_or_apply( + before = federated_language.framework.create_federated_map_or_apply( int_id, applied_id ) after, modified = tree_transformations.strip_placement(before) @@ -1187,26 +1295,28 @@ def test_removes_federated_types_under_function(self): self.assert_has_no_intrinsics_nor_federated_types(after) def test_strip_placement_removes_federated_applys(self): - int_type = computation_types.TensorType(np.int32) - server_int_type = computation_types.FederatedType( - int_type, placements.SERVER - ) - int_ref = building_blocks.Reference('x', int_type) - int_id = building_blocks.Lambda('x', int_type, int_ref) - fed_ref = building_blocks.Reference('x', server_int_type) - applied_id = building_block_factory.create_federated_map_or_apply( + int_type = federated_language.TensorType(np.int32) + server_int_type = federated_language.FederatedType( + int_type, federated_language.SERVER + ) + int_ref = federated_language.framework.Reference('x', int_type) + int_id = federated_language.framework.Lambda('x', int_type, int_ref) + fed_ref = federated_language.framework.Reference('x', server_int_type) + applied_id = federated_language.framework.create_federated_map_or_apply( int_id, fed_ref ) - before = building_block_factory.create_federated_map_or_apply( + before = federated_language.framework.create_federated_map_or_apply( int_id, applied_id ) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, server_int_type ) - type_test_utils.assert_types_identical(after.type_signature, int_type) + federated_language.framework.assert_types_identical( + after.type_signature, int_type + ) self.assertEqual( before.compact_representation(), 'federated_apply(<(x -> x),federated_apply(<(x -> x),x>)>)', @@ -1214,26 +1324,28 @@ def test_strip_placement_removes_federated_applys(self): self.assertEqual(after.compact_representation(), '(x -> x)((x -> x)(x))') def test_strip_placement_removes_federated_maps(self): - int_type = computation_types.TensorType(np.int32) - clients_int_type = computation_types.FederatedType( - int_type, placements.CLIENTS - ) - int_ref = building_blocks.Reference('x', int_type) - int_id = building_blocks.Lambda('x', int_type, int_ref) - fed_ref = building_blocks.Reference('x', clients_int_type) - applied_id = building_block_factory.create_federated_map_or_apply( + int_type = federated_language.TensorType(np.int32) + clients_int_type = federated_language.FederatedType( + int_type, federated_language.CLIENTS + ) + int_ref = federated_language.framework.Reference('x', int_type) + int_id = federated_language.framework.Lambda('x', int_type, int_ref) + fed_ref = federated_language.framework.Reference('x', clients_int_type) + applied_id = federated_language.framework.create_federated_map_or_apply( int_id, fed_ref ) - before = building_block_factory.create_federated_map_or_apply( + before = federated_language.framework.create_federated_map_or_apply( int_id, applied_id ) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, clients_int_type ) - type_test_utils.assert_types_identical(after.type_signature, int_type) + federated_language.framework.assert_types_identical( + after.type_signature, int_type + ) self.assertEqual( before.compact_representation(), 'federated_map(<(x -> x),federated_map(<(x -> x),x>)>)', @@ -1241,129 +1353,149 @@ def test_strip_placement_removes_federated_maps(self): self.assertEqual(after.compact_representation(), '(x -> x)((x -> x)(x))') def test_unwrap_removes_federated_zips_at_server(self): - list_type = computation_types.StructType([np.int32, np.float32] * 2) - server_list_type = computation_types.FederatedType( - list_type, placements.SERVER + list_type = federated_language.StructType([np.int32, np.float32] * 2) + server_list_type = federated_language.FederatedType( + list_type, federated_language.SERVER ) - fed_tuple = building_blocks.Reference('tup', server_list_type) - unzipped = building_block_factory.create_federated_unzip(fed_tuple) - before = building_block_factory.create_federated_zip(unzipped) + fed_tuple = federated_language.framework.Reference('tup', server_list_type) + unzipped = federated_language.framework.create_federated_unzip(fed_tuple) + before = federated_language.framework.create_federated_zip(unzipped) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, server_list_type ) - type_test_utils.assert_types_identical(after.type_signature, list_type) + federated_language.framework.assert_types_identical( + after.type_signature, list_type + ) def test_unwrap_removes_federated_zips_at_clients(self): - list_type = computation_types.StructType([np.int32, np.float32] * 2) - clients_list_type = computation_types.FederatedType( - list_type, placements.SERVER + list_type = federated_language.StructType([np.int32, np.float32] * 2) + clients_list_type = federated_language.FederatedType( + list_type, federated_language.SERVER ) - fed_tuple = building_blocks.Reference('tup', clients_list_type) - unzipped = building_block_factory.create_federated_unzip(fed_tuple) - before = building_block_factory.create_federated_zip(unzipped) + fed_tuple = federated_language.framework.Reference('tup', clients_list_type) + unzipped = federated_language.framework.create_federated_unzip(fed_tuple) + before = federated_language.framework.create_federated_zip(unzipped) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, clients_list_type ) - type_test_utils.assert_types_identical(after.type_signature, list_type) + federated_language.framework.assert_types_identical( + after.type_signature, list_type + ) def test_strip_placement_removes_federated_value_at_server(self): - int_data = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + int_data = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + float_data = federated_language.framework.Literal( + 2.0, federated_language.TensorType(np.float32) ) - float_data = building_blocks.Literal( - 2.0, computation_types.TensorType(np.float32) + fed_int = federated_language.framework.create_federated_value( + int_data, federated_language.SERVER ) - fed_int = building_block_factory.create_federated_value( - int_data, placements.SERVER + fed_float = federated_language.framework.create_federated_value( + float_data, federated_language.SERVER ) - fed_float = building_block_factory.create_federated_value( - float_data, placements.SERVER + tup = federated_language.framework.Struct( + [fed_int, fed_float], container_type=tuple ) - tup = building_blocks.Struct([fed_int, fed_float], container_type=tuple) - before = building_block_factory.create_federated_zip(tup) + before = federated_language.framework.create_federated_zip(tup) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - tuple_type = computation_types.StructWithPythonType( + tuple_type = federated_language.StructWithPythonType( [(None, np.int32), (None, np.float32)], tuple ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, - computation_types.FederatedType(tuple_type, placements.SERVER), + federated_language.FederatedType(tuple_type, federated_language.SERVER), + ) + federated_language.framework.assert_types_identical( + after.type_signature, tuple_type ) - type_test_utils.assert_types_identical(after.type_signature, tuple_type) def test_strip_placement_federated_value_at_clients(self): - int_data = building_blocks.Literal( - 1, computation_types.TensorType(np.int32) + int_data = federated_language.framework.Literal( + 1, federated_language.TensorType(np.int32) + ) + float_data = federated_language.framework.Literal( + 2.0, federated_language.TensorType(np.float32) ) - float_data = building_blocks.Literal( - 2.0, computation_types.TensorType(np.float32) + fed_int = federated_language.framework.create_federated_value( + int_data, federated_language.CLIENTS ) - fed_int = building_block_factory.create_federated_value( - int_data, placements.CLIENTS + fed_float = federated_language.framework.create_federated_value( + float_data, federated_language.CLIENTS ) - fed_float = building_block_factory.create_federated_value( - float_data, placements.CLIENTS + tup = federated_language.framework.Struct( + [fed_int, fed_float], container_type=tuple ) - tup = building_blocks.Struct([fed_int, fed_float], container_type=tuple) - before = building_block_factory.create_federated_zip(tup) + before = federated_language.framework.create_federated_zip(tup) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - tuple_type = computation_types.StructWithPythonType( + tuple_type = federated_language.StructWithPythonType( [(None, np.int32), (None, np.float32)], tuple ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, - computation_types.FederatedType(tuple_type, placements.CLIENTS), + federated_language.FederatedType( + tuple_type, federated_language.CLIENTS + ), + ) + federated_language.framework.assert_types_identical( + after.type_signature, tuple_type ) - type_test_utils.assert_types_identical(after.type_signature, tuple_type) def test_strip_placement_with_called_lambda(self): - int_type = computation_types.TensorType(np.int32) - server_int_type = computation_types.FederatedType( - int_type, placements.SERVER + int_type = federated_language.TensorType(np.int32) + server_int_type = federated_language.FederatedType( + int_type, federated_language.SERVER + ) + federated_ref = federated_language.framework.Reference( + 'outer', server_int_type ) - federated_ref = building_blocks.Reference('outer', server_int_type) - inner_federated_ref = building_blocks.Reference('inner', server_int_type) - identity_lambda = building_blocks.Lambda( + inner_federated_ref = federated_language.framework.Reference( + 'inner', server_int_type + ) + identity_lambda = federated_language.framework.Lambda( 'inner', server_int_type, inner_federated_ref ) - before = building_blocks.Call(identity_lambda, federated_ref) + before = federated_language.framework.Call(identity_lambda, federated_ref) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, server_int_type ) - type_test_utils.assert_types_identical(after.type_signature, int_type) + federated_language.framework.assert_types_identical( + after.type_signature, int_type + ) def test_strip_placement_nested_federated_type(self): - int_type = computation_types.TensorType(np.int32) - server_int_type = computation_types.FederatedType( - int_type, placements.SERVER + int_type = federated_language.TensorType(np.int32) + server_int_type = federated_language.FederatedType( + int_type, federated_language.SERVER ) - tupled_int_type = computation_types.StructType([int_type, int_type]) - tupled_server_int_type = computation_types.StructType([ + tupled_int_type = federated_language.StructType([int_type, int_type]) + tupled_server_int_type = federated_language.StructType([ server_int_type, server_int_type, ]) - fed_ref = building_blocks.Reference('x', server_int_type) - before = building_blocks.Struct([fed_ref, fed_ref]) + fed_ref = federated_language.framework.Reference('x', server_int_type) + before = federated_language.framework.Struct([fed_ref, fed_ref]) after, modified = tree_transformations.strip_placement(before) self.assertTrue(modified) self.assert_has_no_intrinsics_nor_federated_types(after) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( before.type_signature, tupled_server_int_type ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( after.type_signature, tupled_int_type ) diff --git a/tensorflow_federated/python/core/impl/computation/BUILD b/tensorflow_federated/python/core/impl/computation/BUILD deleted file mode 100644 index 978b9db49e..0000000000 --- a/tensorflow_federated/python/core/impl/computation/BUILD +++ /dev/null @@ -1,169 +0,0 @@ -load("@rules_python//python:defs.bzl", "py_library", "py_test") - -package( - default_applicable_licenses = ["//:package_license"], - default_visibility = [ - ":computation_packages", - "//tensorflow_federated/python/core/impl:impl_users", - "//tensorflow_federated/python/core/impl/execution_contexts:execution_contexts_packages", - "//tensorflow_federated/python/core/impl/executors:executors_packages", - "//tensorflow_federated/python/core/impl/federated_context:federated_context_packages", - ], -) - -package_group( - name = "computation_packages", - packages = ["//tensorflow_federated/python/core/impl/computation/..."], -) - -licenses(["notice"]) - -py_library( - name = "computation", - srcs = ["__init__.py"], - visibility = ["//tools/python_package:python_package_tool"], -) - -py_library( - name = "computation_base", - srcs = ["computation_base.py"], - deps = [ - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:typed_object", - ], -) - -py_library( - name = "computation_impl", - srcs = ["computation_impl.py"], - deps = [ - ":computation_base", - ":function_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", - ], -) - -py_test( - name = "computation_impl_test", - size = "small", - srcs = ["computation_impl_test.py"], - deps = [ - ":computation_impl", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:type_test_utils", - ], -) - -py_library( - name = "computation_serialization", - srcs = ["computation_serialization.py"], - deps = [ - ":computation_base", - ":computation_impl", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - ], -) - -py_test( - name = "computation_serialization_test", - size = "small", - srcs = ["computation_serialization_test.py"], - deps = [ - ":computation_base", - ":computation_impl", - ":computation_serialization", - "//tensorflow_federated/python/core/impl/compiler:computation_factory", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_library( - name = "computation_wrapper", - srcs = ["computation_wrapper.py"], - deps = [ - ":computation_base", - ":computation_impl", - ":polymorphic_computation", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - ], -) - -py_test( - name = "computation_wrapper_test", - size = "small", - srcs = ["computation_wrapper_test.py"], - deps = [ - ":computation_impl", - ":computation_wrapper", - ":function_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", - ], -) - -py_library( - name = "function_utils", - srcs = ["function_utils.py"], - deps = [ - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:typed_object", - ], -) - -py_test( - name = "function_utils_test", - size = "small", - srcs = ["function_utils_test.py"], - deps = [ - ":function_utils", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_library( - name = "polymorphic_computation", - srcs = ["polymorphic_computation.py"], - deps = [ - ":computation_impl", - ":function_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_test( - name = "polymorphic_computation_test", - srcs = ["polymorphic_computation_test.py"], - deps = [ - ":computation_impl", - ":polymorphic_computation", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:type_serialization", - ], -) diff --git a/tensorflow_federated/python/core/impl/computation/__init__.py b/tensorflow_federated/python/core/impl/computation/__init__.py deleted file mode 100644 index a024323dfb..0000000000 --- a/tensorflow_federated/python/core/impl/computation/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Libraries for interacting with a computation.""" diff --git a/tensorflow_federated/python/core/impl/computation/computation_base.py b/tensorflow_federated/python/core/impl/computation/computation_base.py deleted file mode 100644 index a4603cea3c..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_base.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines the abstract interface for classes that represent computations.""" - -import abc - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import typed_object - - -class Computation(typed_object.TypedObject, metaclass=abc.ABCMeta): - """An abstract interface for all classes that represent computations.""" - - @property - @abc.abstractmethod - def type_signature(self) -> computation_types.FunctionType: - """Returns the TFF type of this object.""" - raise NotImplementedError - - @abc.abstractmethod - def __call__(self, *args, **kwargs): - """Invokes the computation with the given arguments in the given context. - - Args: - *args: The positional arguments. - **kwargs: The keyword-based arguments. - - Returns: - The result of invoking the computation, the exact form of which depends - on the context. - """ - raise NotImplementedError - - @abc.abstractmethod - def __hash__(self) -> int: - """Hashes the computation. - - TFF backends reserve the right to compile instances of `tff.Computation`, - as they may need different representations or data structures altogether. - As these backends need to be able to cache the result of compilation, we - require that `tff.Computation` subclasses be hashable. - - Returns: - Integer representing the hash value of the `tff.Computation`. - """ - raise NotImplementedError diff --git a/tensorflow_federated/python/core/impl/computation/computation_impl.py b/tensorflow_federated/python/core/impl/computation/computation_impl.py deleted file mode 100644 index 8bf1ae3684..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_impl.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines the implementation of the base Computation interface.""" - -from typing import Optional - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization - - -class ConcreteComputation(computation_base.Computation): - """A representation of a `pb.Computation` in the `tff.Computation` interface. - - This implementation exposes methods to retrieve the backing `pb.Computation`, - as well as the Python representation of this protocol buffer represented by - an instance of `building_blocks.ComputationBuildingBlock`. - """ - - @classmethod - def get_proto(cls, value: 'ConcreteComputation') -> pb.Computation: - py_typecheck.check_type(value, cls) - return value._computation_proto # pylint: disable=protected-access - - @classmethod - def with_type( - cls, - value: 'ConcreteComputation', - type_spec: computation_types.FunctionType, - ) -> 'ConcreteComputation': - py_typecheck.check_type(value, cls) - py_typecheck.check_type(type_spec, computation_types.Type) - # Ensure we are assigning a type-safe signature. - value.type_signature.check_assignable_from(type_spec) - # pylint: disable=protected-access - return cls( - computation_proto=value._computation_proto, - context_stack=value._context_stack, - annotated_type=type_spec, - ) - # pylint: enable=protected-access - - @classmethod - def from_building_block( - cls, building_block: building_blocks.ComputationBuildingBlock - ) -> 'ConcreteComputation': - """Converts a computation building block to a computation impl.""" - py_typecheck.check_type( - building_block, building_blocks.ComputationBuildingBlock - ) - return cls( - computation_proto=building_block.proto, - context_stack=context_stack_impl.context_stack, - annotated_type=building_block.type_signature, # pytype: disable=wrong-arg-types - ) - - def to_building_block(self): - # TODO: b/161560999 - currently destroys annotated type. - # This should perhaps be fixed by adding `type_parameter` to `from_proto`. - return building_blocks.ComputationBuildingBlock.from_proto( - self._computation_proto - ) - - def to_compiled_building_block(self): - return building_blocks.CompiledComputation( - self._computation_proto, type_signature=self.type_signature - ) - - def __init__( - self, - *, - computation_proto: pb.Computation, - context_stack: context_stack_base.ContextStack, - annotated_type: Optional[computation_types.FunctionType] = None, - ): - """Constructs a new instance of ConcreteComputation from the computation_proto. - - Args: - computation_proto: The protocol buffer that represents the computation, an - instance of pb.Computation. - context_stack: The context stack to use. - annotated_type: Optional, type information with additional annotations - that replaces the information in `computation_proto.type`. - - Raises: - TypeError: If `annotated_type` is not `None` and is not compatible with - `computation_proto.type`. - ValueError: If `computation_proto.type` is `None`. - """ - py_typecheck.check_type(computation_proto, pb.Computation) - py_typecheck.check_type(context_stack, context_stack_base.ContextStack) - if computation_proto.type is None: - raise ValueError('Expected `computation_proto.type` to not be `None`.') - type_spec = type_serialization.deserialize_type(computation_proto.type) - - if annotated_type is not None: - if type_spec is None or not type_spec.is_assignable_from(annotated_type): - raise TypeError( - 'annotated_type not compatible with computation_proto.type\n' - f'computation_proto.type: {type_spec}\n' - f'annotated_type: {annotated_type}' - ) - type_spec = annotated_type - - if not isinstance(type_spec, computation_types.FunctionType): - raise TypeError( - f'{type_spec} is not a functional type, from proto: ' - f'{computation_proto}' - ) - - self._type_signature = type_spec - self._context_stack = context_stack - self._computation_proto = computation_proto - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, ConcreteComputation): - return NotImplemented - return self._computation_proto == other._computation_proto - - @property - def type_signature(self) -> computation_types.FunctionType: - return self._type_signature - - def __call__(self, *args, **kwargs): - arg = function_utils.pack_args(self._type_signature.parameter, args, kwargs) - result = self._context_stack.current.invoke(self, arg) - return result - - def __hash__(self) -> int: - return hash(self._computation_proto.SerializeToString(deterministic=True)) diff --git a/tensorflow_federated/python/core/impl/computation/computation_impl_test.py b/tensorflow_federated/python/core/impl/computation/computation_impl_test.py deleted file mode 100644 index 4081b11eab..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_impl_test.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import type_test_utils - - -class ConcreteComputationTest(absltest.TestCase): - - def test_something(self): - # TODO: b/113112108 - Revise these tests after a more complete - # implementation is in place. - - # At the moment, this should succeed, as both the computation body and the - # type are well-formed. - computation_impl.ConcreteComputation( - computation_proto=pb.Computation(**{ - 'type': type_serialization.serialize_type( - computation_types.FunctionType(np.int32, np.int32) - ), - 'intrinsic': pb.Intrinsic(uri='whatever'), - }), - context_stack=context_stack_impl.context_stack, - ) - - # This should fail, as the proto is not well-formed. - with self.assertRaises(NotImplementedError): - computation_impl.ConcreteComputation( - computation_proto=pb.Computation(), - context_stack=context_stack_impl.context_stack, - ) - - # This should fail, as "10" is not an instance of pb.Computation. - with self.assertRaises(TypeError): - computation_impl.ConcreteComputation( - computation_proto=10, - context_stack=context_stack_impl.context_stack, - ) - - def test_with_type_preserves_python_container(self): - struct_return_type = computation_types.FunctionType( - np.int32, computation_types.StructType([(None, np.int32)]) - ) - original_comp = computation_impl.ConcreteComputation( - computation_proto=pb.Computation(**{ - 'type': type_serialization.serialize_type(struct_return_type), - 'intrinsic': pb.Intrinsic(uri='whatever'), - }), - context_stack=context_stack_impl.context_stack, - ) - - list_return_type = computation_types.FunctionType( - np.int32, - computation_types.StructWithPythonType([(None, np.int32)], list), - ) - fn_with_annotated_type = computation_impl.ConcreteComputation.with_type( - original_comp, list_return_type - ) - type_test_utils.assert_types_identical( - list_return_type, fn_with_annotated_type.type_signature - ) - - def test_with_type_raises_non_assignable_type(self): - int_return_type = computation_types.FunctionType(np.int32, np.int32) - original_comp = computation_impl.ConcreteComputation( - computation_proto=pb.Computation(**{ - 'type': type_serialization.serialize_type(int_return_type), - 'intrinsic': pb.Intrinsic(uri='whatever'), - }), - context_stack=context_stack_impl.context_stack, - ) - - list_return_type = computation_types.FunctionType( - np.int32, - computation_types.StructWithPythonType([(None, np.int32)], list), - ) - with self.assertRaises(computation_types.TypeNotAssignableError): - computation_impl.ConcreteComputation.with_type( - original_comp, list_return_type - ) - - -class FromBuildingBlockTest(absltest.TestCase): - - def test_raises_on_none(self): - with self.assertRaises(TypeError): - computation_impl.ConcreteComputation.from_building_block(None) - - def test_converts_building_block_to_computation(self): - buiding_block = building_blocks.Lambda( - 'x', np.int32, building_blocks.Reference('x', np.int32) - ) - computation = computation_impl.ConcreteComputation.from_building_block( - buiding_block - ) - self.assertIsInstance(computation, computation_impl.ConcreteComputation) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/computation/computation_serialization.py b/tensorflow_federated/python/core/impl/computation/computation_serialization.py deleted file mode 100644 index 9b664f366e..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_serialization.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for serializing and deserializing TFF computations.""" - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl - - -def serialize_computation( - computation: computation_base.Computation, -) -> pb.Computation: - """Serializes 'tff.Computation' as a pb.Computation. - - Note: Currently only serialization for computation_impl.ConcreteComputation is - implemented. - - Args: - computation: An instance of `tff.Computation`. - - Returns: - The corresponding instance of `pb.Computation`. - - Raises: - TypeError: If the argument is of the wrong type. - NotImplementedError: for computation variants for which serialization is not - implemented. - """ - py_typecheck.check_type(computation, computation_base.Computation) - - if isinstance(computation, computation_impl.ConcreteComputation): - computation_proto = pb.Computation() - computation_proto.CopyFrom( - computation_impl.ConcreteComputation.get_proto(computation) - ) - return computation_proto - else: - raise NotImplementedError( - 'Serialization of type {} is not currentlyimplemented yet.'.format( - type(computation) - ) - ) - - -def deserialize_computation( - computation_proto: pb.Computation, -) -> computation_base.Computation: - """Deserializes 'tff.Computation' as a pb.Computation. - - Args: - computation_proto: An instance of `pb.Computation`. - - Returns: - The corresponding instance of `tff.Computation`. - - Raises: - TypeError: If the argument is of the wrong type. - """ - py_typecheck.check_type(computation_proto, pb.Computation) - return computation_impl.ConcreteComputation( - computation_proto=computation_proto, - context_stack=context_stack_impl.context_stack, - ) diff --git a/tensorflow_federated/python/core/impl/computation/computation_serialization_test.py b/tensorflow_federated/python/core/impl/computation/computation_serialization_test.py deleted file mode 100644 index 075a87be5f..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_serialization_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_serialization -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types - - -class ComputationSerializationTest(absltest.TestCase): - - def test_serialize_deserialize_round_trip(self): - type_spec = computation_types.TensorType(np.int32) - proto = computation_factory.create_lambda_identity(type_spec) - comp = computation_impl.ConcreteComputation( - computation_proto=proto, - context_stack=context_stack_impl.context_stack, - ) - serialized_comp = computation_serialization.serialize_computation(comp) - deserialize_comp = computation_serialization.deserialize_computation( - serialized_comp - ) - self.assertIsInstance(deserialize_comp, computation_base.Computation) - self.assertEqual(deserialize_comp, comp) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/computation/computation_wrapper.py b/tensorflow_federated/python/core/impl/computation/computation_wrapper.py deleted file mode 100644 index 283b93454d..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_wrapper.py +++ /dev/null @@ -1,514 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for constructing decorators for functions and `tf.function`s.""" - -import collections -from collections.abc import Callable, Iterable -import inspect -from typing import Optional - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import polymorphic_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions - - -def _parameters(fn): - return inspect.signature(fn).parameters.values() - - -def _check_parameters(parameters): - """Ensure only non-varargs positional-or-keyword arguments.""" - for parameter in parameters: - if parameter.default is not inspect.Parameter.empty: - # We don't have a way to build defaults into the function's type. - raise TypeError( - 'TFF does not support default parameters. Found parameter ' - f'`{parameter.name}` with default value {parameter.default}' - ) - if parameter.kind is inspect.Parameter.POSITIONAL_ONLY: - # We don't have a way to encode positional-only into the function's type. - raise TypeError( - 'TFF does not support positional-only parameters. Found parameter ' - f'`{parameter.name}` which appears before a `/` entry.' - ) - if parameter.kind is inspect.Parameter.KEYWORD_ONLY: - # We don't have a way to encode keyword-only into the function's type. - raise TypeError( - 'TFF does not support keyword-only arguments. Found parameter ' - f'`{parameter.name}` which appears after a `*` or `*args` entry.' - ) - if parameter.kind in ( - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - ): - # For concrete functions, we can't determine at tracing time which - # arguments should be bundled into args vs. kwargs, since arguments can - # be passed by position *or* by keyword at later call sites. - raise TypeError( - 'TFF does not support varargs. Found varargs parameter ' - f'`{parameter.name}`.' - ) - if parameter.kind is not inspect.Parameter.POSITIONAL_OR_KEYWORD: - raise AssertionError(f'Unexpected parameter kind: {parameter.kind}') - - -def _wrap_polymorphic( - fn, wrapper_fn, infer_type_fn -) -> polymorphic_computation.PolymorphicComputation: - """Wraps `fn` in `wrapper_fn` at invocation time.""" - try: - name = fn.__name__ - except AttributeError: - name = None - - def _polymorphic_wrapper( - parameter_type: computation_types.Type, unpack: Optional[bool] - ): - return wrapper_fn(fn, parameter_type, unpack=unpack, name=name) - - return polymorphic_computation.PolymorphicComputation( - _polymorphic_wrapper, infer_type_fn - ) - - -def _wrap_concrete( - fn, wrapper_fn, parameter_type -) -> computation_impl.ConcreteComputation: - """Wraps `fn` in `wrapper_fn` given the provided `parameter_type`.""" - try: - name = fn.__name__ - except AttributeError: - name = None - - concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=name) - py_typecheck.check_type( - concrete_fn, - computation_impl.ConcreteComputation, - 'value returned by the wrapper', - ) - result_parameter_type = concrete_fn.type_signature.parameter - if ( - result_parameter_type is not None - and not result_parameter_type.is_equivalent_to(parameter_type) - ): - raise TypeError( - 'Expected a concrete function that takes parameter {}, got one ' - 'that takes {}.'.format( - str(parameter_type), str(concrete_fn.type_signature.parameter) - ) - ) - return concrete_fn - - -def _parameter_type( - parameters, parameter_types: tuple[computation_types.Type, ...] -) -> Optional[computation_types.Type]: - """Bundle any user-provided parameter types into a single argument type.""" - parameter_names = [parameter.name for parameter in parameters] - if not parameter_types and not parameters: - return None - if len(parameter_types) == 1: - parameter_type = parameter_types[0] - if parameter_type is None and not parameters: - return None - if len(parameters) == 1: - return parameter_type - # There is a single parameter type but multiple parameters. - if not isinstance(parameter_type, computation_types.StructType) or len( - parameter_type - ) != len(parameters): - raise TypeError( - f'Function with {len(parameters)} parameters must have a parameter ' - 'type with the same number of parameters. Found parameter type ' - f'{parameter_type}.' - ) - name_list_from_types = [ - n for n, _ in parameter_type.items() if n is not None - ] - if name_list_from_types: - if len(name_list_from_types) != len(parameter_type): - raise TypeError( - 'Types with both named and unnamed fields cannot be unpacked into ' - f'argument lists. Found parameter type {parameter_type}.' - ) - if set(name_list_from_types) != set(parameter_names): - raise TypeError( - 'Function argument names must match field names of parameter type. ' - f'Found argument names {parameter_names}, which do not match ' - f'{name_list_from_types}, the top-level fields of the parameter ' - f'type {parameter_type}.' - ) - # The provided parameter type has all named fields which exactly match - # the names of the function's parameters. - return parameter_type - else: - # The provided parameter type has no named fields. Apply the names from - # the function parameters. - parameter_types = (v for (_, v) in parameter_type.items()) - return computation_types.StructWithPythonType( - list(zip(parameter_names, parameter_types)), collections.OrderedDict - ) - elif len(parameters) == 1: - # If there are multiple provided argument types but the function being - # decorated only accepts a single argument, tuple the arguments together. - return computation_types.StructType(parameter_types) - if len(parameters) != len(parameter_types): - raise TypeError( - f'Function with {len(parameters)} parameters is ' - f'incompatible with provided argument types {parameter_types}.' - ) - # The function has `n` parameters and `n` parameter types. - # Zip them up into a structure using the names from the function as keys. - return computation_types.StructWithPythonType( - list(zip(parameter_names, parameter_types)), collections.OrderedDict - ) - - -def _wrap( - fn, - wrapper_fn, - parameter_types: tuple[computation_types.Type, ...], - infer_type_fn: Callable[[object], computation_types.Type], -): - """Wraps a possibly-polymorphic `fn` in `wrapper_fn`. - - If `parameter_type` is `None` and `fn` takes any arguments (even with default - values), `fn` is inferred to be polymorphic and won't be passed to - `wrapper_fn` until invocation time (when concrete parameter types are - available). - - `wrapper_fn` must accept three positional arguments and one defaulted argument - `name`: - - * `target_fn`, the Python function to be wrapped. - - * `parameter_types`, the user-provded list of parameter types. - - * `unpack`, an argument which will be passed on to - `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping `target_fn`. - See that function for details. - - * Optional `name`, the name of the function that is being wrapped (only for - debugging purposes). - - Args: - fn: The function or tf.function to wrap as a computation. - wrapper_fn: The Python callable that performs actual wrapping. The object to - be returned by this function should be an instance of a - `tff.framework.ConcreteComputation`. - parameter_types: Types of any arguments to `fn`. - infer_type_fn: ... - - Returns: - Either the result of wrapping (an object that represents the computation), - or a polymorphic callable that performs wrapping upon invocation based on - argument types. The returned function still may accept multiple - arguments (it has not yet had - `function_uils.wrap_as_zero_or_one_arg_callable` applied to it). - """ - parameters = _parameters(fn) - # NOTE: many of the properties checked here are only necessary for - # non-polymorphic computations whose type signatures must be resolved - # prior to use. However, we continue to enforce these requirements even - # in the polymorphic case in order to avoid creating an inconsistency. - _check_parameters(parameters) - - if parameters and not parameter_types: - # There is no TFF type specification, and the function/tf.function declares - # parameters. Create a polymorphic template. - wrapped_fn = _wrap_polymorphic(fn, wrapper_fn, infer_type_fn) - else: - # Either we have a concrete parameter type, or this is no-arg function. - parameter_type = _parameter_type(parameters, parameter_types) - wrapped_fn = _wrap_concrete(fn, wrapper_fn, parameter_type) - - # When applying a decorator, the __doc__ attribute with the documentation - # in triple-quotes is not automatically transferred from the function on - wrapped_fn.__doc__ = getattr(fn, '__doc__', None) - return wrapped_fn - - -def _is_function(obj): - # TFF supports passing type specifications (i.e. objects that can be turned - # into a `tff.Type`) as arguments to a computation decorator. In some cases - # those type specifications (e.g. np.int32) are a `type`, making them - # `Callable`, but they should not be treated as the function being decorated. - if isinstance(obj, type): - return False - return isinstance(obj, Callable) - - -class ComputationReturnedNoneError(ValueError): - """Error for computations which return `None` or do not return.""" - - def __init__(self, fn): - code = fn.__code__ - line_number = code.co_firstlineno - filename = code.co_filename - message = ( - f'The function defined on line {line_number} of file {filename} ' - "returned `None` (or didn't explicitly `return` at all), but TFF " - 'computations must return some non-`None` value.' - ) - super().__init__(message) - - -class ComputationWrapper: - """A class for creating wrappers that convert functions into computations. - - This class builds upon the _wrap() function defined above, adding on - functionality shared between the computation and `federated_computation` - decorators. The shared functionality includes relating formal Python function - parameters and call arguments to TFF types, packing and unpacking arguments, - verifying types, and support for polymorphism. - - Here's how one can use `ComputationWrapper` to construct a decorator/wrapper - named `xyz`: - - ```python - def my_wrapper_fn(target_fn, parameter_type, unpack, name=None): - ... - xyz = computation_wrapper.ComputationWrapper(my_wrapper_fn) - ``` - - The resulting `xyz` can then be used either as an `@xyz(...)` decorator or as - a manual wrapping function: `wrapped_fn = xyz(my_func, ...)`. The latter - method may be preferable when using functions from an external module or - for wrapping an anonymous lambda. - - The decorator can be used in two ways: - 1. Invoked with a single positional argument specifying the types of the - function's arguments (`@xyz(some_argument_type)`). - 2. Invoked with no arguments (`@xyz` or `@xyz()`). This is used for functions - which take no arguments, or functions which are polymorphic (used with - multiple different argument types). - - Here's how the decorator behaves in each case: - - If the user specifies a tuple type in an unbundled form (simply by listing the - types of its constituents as separate arguments), the tuple type is formed on - the user's behalf for convenience. - - 1. When the decorator is invoked with positional arguments: - - ```python - @xyz(('x', np.int32), ('y', np.int32)) - ``` - - The decorator arguments must be instances of `types.Type`, or something - convertible to it by `types.to_type()`. The arguments are interpreted as - the specification of the parameter of the computation being constructed by - the decorator. Since the formal parameter to computations is always a - single argument, multiple arguments to the decorator will be packed into a - tuple type. This means that the following two invocations behave the same: - - ``` - @xyz(('x', np.int32), ('y', np.int32)) # gets packed into the below - @xyz((('x', np.int32), ('y', np.int32))) - ``` - - In the above example, the computation will accept as an argument a pair - of integers named `x` and `y`. - - The function being decorated this way must declare at least one parameter. - - a. If the Python function declares only one parameter, that parameter will - receive all arguments packed into a single value: - - ```python - @xyz(('x', np.int32), ('y', np.int32)) - def my_comp(coord): - ... # use `coord.x` and `coord.y` - ``` - - b. If the Python function declares multiple parameters, the computation's - parameter type must be convertible to type `tff.StructType` - (usually a list containing types or pairs of `(str, types.Type)`. - - ```python - # With explicitly named parameters - @xyz(('x', np.int32), ('y', np.int32)) - def my_comp(x, y): - ... # use `x` and `y` - - # Without explicitly named parameters - @xyz(np.int32, np.int32) - def my_comp(x, y): - ... # use `x` and `y` - ``` - - The number and order of parameters in the decorator arguments and the - Python function must match. For named elements, the names in the - decorator and the Python function must also match. - - 2. When the decorator is specified without arguments (`@xyz` or `@xyz()`): - - a. If the Python function declares no parameters, the decorator constructs - a no-parameter computation, as in the following example: - - ```python - @xyz - def my_comp(): - ... - ``` - - b. If the function does declare at least one parameter, it is treated as a - polymorphic function that's instantiated in each concrete context in - which it's used based on the types of its arguments. The decorator still - handles the plumbing and parameter type inference. - - For example: - - ```python - @xyz - def my_comp(x, y): - ... - ``` - - In this case, `my_comp` becomes a polymorphic callable, with the actual - construction postponed. Suppose it's then used as follows, e.g., in an - orchestration context: - - ```python - my_comp(5.0, True) - ``` - - At the time of invocation, the decorator uses the information contained - in the call arguments 5.0 and True to infer the computation's parameter - type signature, and once the types have been determined, proceeds in - exactly the same manner as already described in (1) above. - - It is important to note that the arguments of the invocation are not - simply passed into the body of the Python function being decorated. - The parameter type inference step is all that differs between the - polymorphic case and case (1) above. - - Polymorphic functions are the only case where no constraints exist on - the kinds of arguments that may be present: declaring default values, - `*args` or `**kwargs`, and any combination of those are valid. The - mapping is resolved at invocation time based on arguments of the call, - as in the example below: - - ```python - @xyz - def my_comp(x, y=True, *args, **kwargs): - ... - - my_comp(1, False, 2, 3, 'foo', name='bar') - ``` - - As with all polymorphic functions, no construction is actually performed - until invocation, and at invocation time, the default parameter values - are used alongside those actually used during the invocation to infer - the computation's parameter type. The code that actually constructs the - computation is oblivious to whether parameters of the Python function - being decorated were driven by the default values, or by the arguments - of the actual call. - - Note that the last argument to the function in the example above will - be inferred as type `('name', str)`, not just `str`. - - For more examples of usage, see `computation_wrapper_test`. - """ - - def __init__( - self, - wrapper_fn: Callable[..., computation_base.Computation], - to_type_fn: Callable[ - [object], computation_types.Type - ] = computation_types.to_type, - infer_type_fn: Callable[ - [object], computation_types.Type - ] = type_conversions.infer_type, - ): - """Construct a new wrapper/decorator for the given wrapper callable. - - Args: - wrapper_fn: The Python callable that performs actual wrapping (as in the - specification of `_wrap`). - to_type_fn: A `Callable` used to convert a backend-specific type - specification to a `tff.Type`. - infer_type_fn: A `Callable` used to convert a backend-specific values to a - `tff.Type`. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - self._wrapper_fn = wrapper_fn - self._to_type_fn = to_type_fn - self._infer_type_fn = infer_type_fn - - def __call__(self, *args, **kwargs): - """Handles the different modes of usage of the decorator/wrapper. - - This method only acts as a frontend that allows this class to be used as a - decorator or wrapper in a variety of ways. The actual wrapping is performed - by the private method `_wrap`. - - Args: - *args: Positional arguments for the computation decorator. - **kwargs: Keyword arguments passed to individual computation strategies. - - Returns: - Either a result of wrapping, or a callable that expects a function, - method, or a tf.function and performs wrapping on it, depending on - specific usage pattern. - - Raises: - TypeError: if the arguments are of the wrong types. - """ - - def _to_types( - objs: Iterable[object], - ) -> tuple[Optional[computation_types.Type], ...]: - result = [] - for obj in objs: - if obj is not None: - result.append(self._to_type_fn(obj)) - else: - result.append(None) - return tuple(result) - - if not args: - # If invoked as a decorator, and with an empty argument list as "@xyz()" - # applied to a function definition, expect the Python function being - # decorated to be passed in the subsequent call, and potentially create - # a polymorphic callable. The parameter type is unspecified. - # Deliberate wrapping with a lambda to prevent the caller from being able - # to accidentally specify parameter type as a second argument. - # The tricky partial recursion is needed to inline the logic in the - # "success" case below. - provided_types = [] - return lambda fn: _wrap( - fn, self._wrapper_fn, provided_types, self._infer_type_fn - ) - elif _is_function(args[0]): - # If the first argument on the list is a Python function, instance method, - # or a tf.function, this is the one that's being wrapped. This is the case - # of either a decorator invocation without arguments as "@xyz" applied to - # a function definition, of an inline invocation as - # `... = xyz(lambda....).` - # Any of the following arguments, if present, are the arguments to the - # wrapper that are to be interpreted as the type specification. - fn = args[0] - provided_types = _to_types(args[1:]) - return _wrap(fn, self._wrapper_fn, provided_types, self._infer_type_fn) - else: - provided_types = _to_types(args) - return lambda fn: _wrap( - fn, self._wrapper_fn, provided_types, self._infer_type_fn - ) diff --git a/tensorflow_federated/python/core/impl/computation/computation_wrapper_test.py b/tensorflow_federated/python/core/impl/computation/computation_wrapper_test.py deleted file mode 100644 index d7a0ed2c09..0000000000 --- a/tensorflow_federated/python/core/impl/computation/computation_wrapper_test.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization - - -class WrappedForTest(computation_impl.ConcreteComputation): - """A class that represents a wrapped function for testing purposes. - - Upon invocation, it returns a string of the form P : T -> R, where P is the - parameter tuple (or None), T is its type (or None), and R is the returned - result, all converted into strings via str(). - """ - - def __init__(self, fn, parameter_type, unpack, name=None): - del name # Unused. - fn_type = computation_types.FunctionType(parameter_type, np.str_) - test_proto = pb.Computation(type=type_serialization.serialize_type(fn_type)) - super().__init__( - computation_proto=test_proto, - context_stack=context_stack_impl.context_stack, - annotated_type=fn_type, - ) - - self._fn = function_utils.wrap_as_zero_or_one_arg_callable( - fn, parameter_type, unpack - ) - - @property - def fn(self): - return self._fn - - -class ContextForTest(context_base.SyncContext): - - def invoke(self, comp, arg): - result = comp.fn(arg) if comp.type_signature.parameter else comp.fn() - return '{} : {} -> {}'.format( - str(arg), str(comp.type_signature.parameter), str(result) - ) - - -test_wrap = computation_wrapper.ComputationWrapper(WrappedForTest) - - -class ComputationWrapperTest(absltest.TestCase): - - def test_as_decorator_without_arguments_on_no_parameter_py_fn(self): - @test_wrap - def my_fn(): - """This is my fn.""" - return 10 - - self.assertEqual(my_fn(), 'None : None -> 10') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_wrapper_without_arguments_on_no_parameter_lambda(self): - self.assertEqual(test_wrap(lambda: 10)(), 'None : None -> 10') - - def test_as_wrapper_without_arguments_on_no_parameter_partial(self): - def identity(x): - return x - - self.assertEqual( - test_wrap(functools.partial(identity, 10))(), 'None : None -> 10' - ) - - def test_as_decorator_with_empty_arguments_on_no_parameter_py_fn(self): - @test_wrap() - def my_fn(): - """This is my fn.""" - return 10 - - self.assertEqual(my_fn(), 'None : None -> 10') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_decorator_with_one_argument_on_no_parameter_py_fn(self): - with self.assertRaises(TypeError): - - @test_wrap(np.int32) - def _(): - pass - - def test_as_wrapper_with_one_argument_on_no_parameter_lambda(self): - with self.assertRaises(TypeError): - test_wrap(lambda: None, np.int32) - - def test_as_decorator_with_one_argument_on_one_parameter_py_fn(self): - - @test_wrap(np.int32) - def my_fn(x): - """This is my fn.""" - return x + 10 - - self.assertEqual(my_fn(5), '5 : int32 -> 15') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_wrapper_with_one_argument_on_one_parameter_lambda(self): - self.assertEqual( - test_wrap(lambda x: x + 10, np.int32)(5), '5 : int32 -> 15' - ) - - def test_as_decorator_with_non_tuple_argument_on_two_parameter_py_fn(self): - with self.assertRaises(TypeError): - - @test_wrap(np.int32) - def _(x, y): - del x, y # Unused. - pass - - def test_as_wrapper_with_non_tuple_argument_on_two_parameter_lambda(self): - with self.assertRaises(TypeError): - - def my_fn(x, y): - del x, y # Unused. - pass - - test_wrap(my_fn, np.int32) - - def test_as_decorator_with_two_tuple_argument_on_three_param_py_fn(self): - with self.assertRaises(TypeError): - - @test_wrap((np.int32, np.int32)) - def _(x, y, z): - del x, y, z # Unused. - pass - - def test_as_wrapper_with_two_tuple_argument_on_three_param_lambda(self): - with self.assertRaises(TypeError): - test_wrap(lambda x, y, z: None, (np.int32, np.int32)) - - def test_as_decorator_with_arg_name_mismatching_element_name_in_py_fn(self): - with self.assertRaises(TypeError): - - @test_wrap([('x', np.int32), ('y', np.int32)]) - def _(x, z): - del x, z # Unused. - pass - - def test_as_wrapper_with_arg_name_mismatching_element_name_in_lambda(self): - with self.assertRaises(TypeError): - test_wrap(lambda x, z: None, [('x', np.int32), ('y', np.int32)]) - - def test_as_decorator_with_tuple_params_on_two_parameter_py_fn(self): - - @test_wrap((np.int32, np.int32)) - def my_fn(x, y): - """This is my fn.""" - return x + y - - self.assertEqual(my_fn(1, 2), ' : -> 3') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_wrapper_with_tuple_params_on_two_parameter_py_fn(self): - wrapped = test_wrap(lambda x, y: x + y, (np.int32, np.int32)) - self.assertEqual(wrapped(1, 2), ' : -> 3') - - def test_as_decorator_with_tuple_params_on_one_parameter_py_fn(self): - # Computations only have a single parameter (or none), and we allow the - # flexibility of feeding tuple-like parameters in pieces by specifying - # tuple elementas as multiple arguments. This is independent of how the - # backing Python function binds to the argument on the definition side. - # Thus, the ordinary linter check is inapplicable, as there's exists no - # direct connection between the signature of the call and that of the - # Python definition. The TFF type decouples one from the other. - @test_wrap([('x', np.int32), ('y', np.int32)]) - def my_fn(arg): - """This is my fn.""" - return arg.x + arg.y - - self.assertEqual( - my_fn(1, 2), # pylint: disable=too-many-function-args - ' : -> 3', - ) - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_wrapper_with_tuple_params_on_one_parameter_py_fn(self): - self.assertEqual( - test_wrap(lambda arg: arg[0] + arg[1], (np.int32, np.int32))(1, 2), - '<1,2> : -> 3', - ) - - def test_as_decorator_with_named_tuple_params_on_two_param_py_fn(self): - - @test_wrap([('x', np.int32), ('y', np.int32)]) - def my_fn(x, y): - """This is my fn.""" - return x + y - - self.assertEqual(my_fn(1, 2), ' : -> 3') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_wrapper_with_named_tuple_params_on_two_param_py_fn(self): - wrapped = test_wrap(lambda x, y: x + y, [('x', np.int32), ('y', np.int32)]) - self.assertEqual(wrapped(1, 2), ' : -> 3') - - def test_as_decorator_without_arguments_on_py_fn_with_one_param(self): - @test_wrap - def my_fn(x): - """This is my fn.""" - return x + 1 - - self.assertEqual(my_fn(10), '<10> : -> 11') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_wrapper_without_arguments_on_py_fn_with_one_param(self): - wrapped = test_wrap(lambda x: x + 1) - self.assertEqual(wrapped(10), '<10> : -> 11') - - def test_as_decorator_without_arguments_on_py_fn_with_two_params(self): - @test_wrap - def my_fn(x, y): - """This is my fn.""" - return x + y - - self.assertEqual(my_fn(10, 20), '<10,20> : -> 30') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_decorator_with_empty_arguments_on_py_fn_with_one_param(self): - @test_wrap() - def my_fn(x): - """This is my fn.""" - return x + 1 - - self.assertEqual(my_fn(10), '<10> : -> 11') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_as_decorator_with_empty_arguments_on_py_fn_with_two_params(self): - @test_wrap() - def my_fn(x, y): - """This is my fn.""" - return x + y - - self.assertEqual(my_fn(10, 20), '<10,20> : -> 30') - self.assertEqual(my_fn.__doc__, 'This is my fn.') - - def test_with_integer_args(self): - with self.assertRaises(TypeError): - test_wrap(10, 20) - - def test_with_varargs_no_type(self): - with self.assertRaises(TypeError): - - @test_wrap - def _(*args): - """This is my fn.""" - return sum(args) - - def test_with_varargs_scalar_type(self): - with self.assertRaises(TypeError): - - @test_wrap(np.int32) - def _(*args): - """This is my fn.""" - return sum(args) - - def test_with_varargs_tuple_type(self): - with self.assertRaises(TypeError): - - @test_wrap([np.int32, np.int32, np.int32, np.int32]) - def _(x, y, *args): - """This is my fn.""" - return x + y + sum(args) - - def test_with_kwargs_no_type(self): - with self.assertRaises(TypeError): - - @test_wrap - def _(**kwargs): - """This is my fn.""" - return kwargs['x'] / kwargs['y'] - - def test_as_decorator_with_unbundled_arguments(self): - - @test_wrap(np.int32, np.int32) - def foo(unused_x, unused_y): - return 99 - - self.assertEqual( - foo(unused_y=20, unused_x=10), - ' : -> 99', - ) - - def test_as_decorator_with_named_positional_arguments(self): - - @test_wrap(np.int32, np.int32) - def foo(unused_x, unused_y): - return 99 - - expected = ( - ' : -> 99' - ) - self.assertEqual(foo(unused_x=10, unused_y=20), expected) - self.assertEqual(foo(10, unused_y=20), expected) - self.assertEqual(foo(unused_y=20, unused_x=10), expected) - - def test_as_decorator_with_optional_arguments(self): - with self.assertRaisesRegex(TypeError, 'default'): - - @test_wrap(np.int32, np.int32) - def _(unused_x=10, unused_y=20): - return 99 - - def test_as_wrapper_with_unbundled_arguments(self): - foo = test_wrap(lambda unused_x, unused_y: 99, np.int32, np.int32) - self.assertEqual( - foo(10, 20), - ' : -> 99', - ) - - def test_as_wrapper_with_one_argument_instance_method(self): - class IntWrapper: - - def __init__(self, x): - self._x = x - - def multiply_by(self, y): - return self._x * y - - five = IntWrapper(5) - wrapped = test_wrap(five.multiply_by, np.int32) - self.assertEqual(wrapped(2), '2 : int32 -> 10') - - def test_as_wrapper_with_no_argument_instance_method(self): - class C: - - def __init__(self, x): - self._x = x - - def my_method(self): - return self._x - - c = C(99) - wrapped = test_wrap(c.my_method) - self.assertEqual(wrapped(), 'None : None -> 99') - - def test_as_wrapper_with_class_property(self): - class C: - - @property - def x(self): - return 99 - - c = C() - with self.assertRaises(TypeError): - test_wrap(c.x) - - def test_as_wrapper_with_classmethod(self): - class C: - - @classmethod - def prefix(cls, msg): - return f'{cls.__name__}_{msg}' - - wrapped = test_wrap(C.prefix) - self.assertEqual(wrapped('foo'), ' : -> C_foo') - - -if __name__ == '__main__': - with context_stack_impl.context_stack.install(ContextForTest()): - absltest.main() diff --git a/tensorflow_federated/python/core/impl/computation/function_utils.py b/tensorflow_federated/python/core/impl/computation/function_utils.py deleted file mode 100644 index 96ca57251b..0000000000 --- a/tensorflow_federated/python/core/impl/computation/function_utils.py +++ /dev/null @@ -1,523 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for Python functions, defuns, and other types of callables.""" - -from collections.abc import Mapping, Sequence -import inspect -import types -from typing import Optional - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import typed_object - - -def is_signature_compatible_with_types( - signature: inspect.Signature, *args, **kwargs -) -> bool: - """Determines if functions matching signature accept `args` and `kwargs`. - - Args: - signature: An instance of `inspect.Signature` to verify agains the - arguments. - *args: Zero or more positional arguments, all of which must be instances of - computation_types.Type or something convertible to it by - computation_types.to_type(). - **kwargs: Zero or more keyword arguments, all of which must be instances of - computation_types.Type or something convertible to it by - computation_types.to_type(). - - Returns: - `True` or `False`, depending on the outcome of the test. - - Raises: - TypeError: if the arguments are of the wrong computation_types. - """ - try: - bound_args = signature.bind(*args, **kwargs) - except TypeError: - return False - - # If we have no defaults then `bind` will have raised `TypeError` if the - # signature was not compatible with *args and **kwargs. - if all( - p.default is inspect.Parameter.empty - for p in signature.parameters.values() - ): - return True - - # Otherwise we need to check the defaults against the types that were given to - # ensure they are compatible. - for p in signature.parameters.values(): - if p.default is inspect.Parameter.empty or p.default is None: - # No default value or optional. - continue - arg_value = bound_args.arguments.get(p.name, p.default) - if arg_value is p.default: - continue - arg_type = computation_types.to_type(arg_value) - default_type = type_conversions.infer_type(p.default) - if not arg_type.is_assignable_from(default_type): - return False - return True - - -def is_argument_struct(arg) -> bool: - """Determines if 'arg' is interpretable as an argument struct. - - Args: - arg: A value or type to test. - - Returns: - True iff 'arg' is either a `Struct` in which all unnamed elements - precede named ones, or a `StructType` with this property, or something - that can be converted into the latter by computation_types.to_type(). - - Raises: - TypeError: If the argument is neither an `structure.Struct`, - nor a type spec. - """ - if isinstance(arg, structure.Struct): - elements = structure.to_elements(arg) - elif isinstance(arg, typed_object.TypedObject): - return is_argument_struct(arg.type_signature) - else: - if arg is not None: - arg = computation_types.to_type(arg) - if isinstance(arg, computation_types.StructType): - elements = list(arg.items()) - else: - return False - max_unnamed = -1 - min_named = len(elements) - for idx, element in enumerate(elements): - if element[0]: - min_named = min(min_named, idx) - else: - max_unnamed = idx - return max_unnamed < min_named - - -def unpack_args_from_struct( - struct_with_args, -) -> tuple[list[computation_types.Type], dict[str, computation_types.Type]]: - """Extracts argument types from a struct. - - Args: - struct_with_args: An instance of either an `struct.Struct` or - computation_types.StructType` (or something convertible to it by - computation_types.to_type()), on which is_argument_struct() is True. - - Returns: - A pair (args, kwargs) containing tuple elements from 'struct_with_args'. - - Raises: - TypeError: if 'struct_with_args' is of a wrong type. - """ - if not is_argument_struct(struct_with_args): - raise TypeError('Not an argument struct: {}.'.format(struct_with_args)) - if isinstance(struct_with_args, structure.Struct): - elements = structure.to_elements(struct_with_args) - elif isinstance(struct_with_args, typed_object.TypedObject): - elements = [] - for index, (name, _) in enumerate(struct_with_args.type_signature.items()): # pytype: disable=attribute-error - if name is not None: - elements.append((name, getattr(struct_with_args, name))) - else: - elements.append((None, struct_with_args[index])) - else: - struct_with_args = computation_types.to_type(struct_with_args) - if not isinstance(struct_with_args, computation_types.StructType): - raise ValueError( - f'Expected a `tff.StructType`, found {struct_with_args}.' - ) - elements = list(struct_with_args.items()) - args = [] - kwargs = {} - for name, value in elements: - if name is not None: - kwargs[name] = value - else: - args.append(value) - return args, kwargs - - -def pack_args_into_struct( - args: Sequence[object], kwargs: Mapping[str, object], type_spec=None -) -> structure.Struct: - """Packs positional and keyword arguments into a `Struct`. - - If 'type_spec' is not None, it must be a `StructType` or something that's - convertible to it by computation_types.to_type(). The assignment of arguments - to fields of the struct follows the same rule as during function calls. If - 'type_spec' is None, the positional arguments precede any of the keyword - arguments, and the ordering of the keyword arguments matches the ordering in - which they appear in kwargs. If the latter is an OrderedDict, the ordering - will be preserved. On the other hand, if the latter is an ordinary unordered - dict, the ordering is arbitrary. - - Args: - args: Positional arguments. - kwargs: Keyword arguments. - type_spec: The optional type specification (either an instance of - `computation_types.StructType` or something convertible to it), or None if - there's no type. Used to drive the arrangements of args into fields of the - constructed struct, as noted in the description. - - Returns: - An struct containing all the arguments. - - Raises: - TypeError: if the arguments are of the wrong computation_types. - """ - if type_spec is not None: - type_spec = computation_types.to_type(type_spec) - if not type_spec: - return structure.Struct( - [(None, arg) for arg in args] + list(kwargs.items()) - ) - else: - py_typecheck.check_type(type_spec, computation_types.StructType) - if not is_argument_struct(type_spec): - raise TypeError( - 'Parameter type {} does not have a structure of an argument struct, ' - 'and cannot be populated from multiple positional and keyword ' - 'arguments'.format(type_spec) - ) - else: - result_elements = [] - positions_used = set() - keywords_used = set() - for index, (name, elem_type) in enumerate(type_spec.items()): # pytype: disable=attribute-error - if index < len(args): - # This argument is present in `args`. - if name is not None and name in kwargs: - raise TypeError('Argument `{}` specified twice.'.format(name)) - else: - arg_value = args[index] - result_elements.append((name, arg_value)) - positions_used.add(index) - elif name is not None and name in kwargs: - # This argument is present in `kwargs`. - arg_value = kwargs[name] - result_elements.append((name, arg_value)) - keywords_used.add(name) - elif name: - raise TypeError(f'Missing argument `{name}` of type {elem_type}.') - else: - raise TypeError( - f'Missing argument of type {elem_type} at position {index}.' - ) - positions_missing = set(range(len(args))).difference(positions_used) - if positions_missing: - raise TypeError( - f'Positional arguments at {positions_missing} not used.' - ) - keywords_missing = set(kwargs.keys()).difference(keywords_used) - if keywords_missing: - raise TypeError(f'Keyword arguments at {keywords_missing} not used.') - return structure.Struct(result_elements) - - -def pack_args( - parameter_type, args: Sequence[object], kwargs: Mapping[str, object] -): - """Pack arguments into a single one that matches the given parameter type. - - The arguments may or may not be packed into a `Struct`, depending on the type - of the parameter, and how many arguments are present. - - Args: - parameter_type: The type of the single parameter expected by a computation, - an instance of computation_types.Type or something convertible to it, or - None if the computation is not expecting a parameter. - args: Positional arguments of a call. - kwargs: Keyword arguments of a call. - - Returns: - A single value object of type that matches 'parameter_type' that contains - all the arguments, or None if the 'parameter_type' is None. - - Raises: - TypeError: if the args/kwargs do not match the given parameter type. - """ - if parameter_type is None: - # If there's no parameter type, there should be no args of any kind. - if args or kwargs: - raise TypeError('Was not expecting any arguments.') - else: - return None - parameter_type = computation_types.to_type(parameter_type) - if not args and not kwargs: - raise TypeError( - 'Declared a parameter of type {}, but got no arguments.'.format( - parameter_type - ) - ) - single_positional_arg = len(args) == 1 and not kwargs - if single_positional_arg: - return args[0] - if not isinstance(parameter_type, computation_types.StructType): - # If not a `StructType`, a single positional argument is the only - # supported call style. - raise TypeError( - 'Parameter type {} is compatible only with a single positional ' - 'argument, but found {} positional and {} keyword args.'.format( - parameter_type, len(args), len(kwargs) - ) - ) - if not is_argument_struct(parameter_type): - raise TypeError( - 'Parameter type {} does not have a structure of an argument ' - 'struct, and cannot be populated from multiple positional and ' - 'keyword arguments; please construct a struct before the ' - 'call.'.format(parameter_type) - ) - return pack_args_into_struct(args, kwargs, parameter_type) - - -def _infer_unpack_needed( - fn: types.FunctionType, - parameter_type: Optional[computation_types.Type], - should_unpack: Optional[bool] = None, -) -> bool: - """Returns whether parameter_type must be unpacked when calling fn. - - Args: - fn: The function to be invoked. - parameter_type: The TFF type of the parameter bundle to be accepted by the - returned callable, if any, or None if there's no parameter. - should_unpack: Default or expected return value; None implies the inferred - value should be returned. If either unpacking or packing could work, and - should_unpack is not None, then should_unpack is returned. - - Returns: - A `bool` indicating whether or not to unpack. - """ - # TODO: b/113112885 - Revisit whether the 3-way 'unpack' knob is sufficient - # for our needs, or more options are needed. - if should_unpack not in [True, False, None]: - raise TypeError( - 'The unpack argument has an unexpected value {!r}.'.format( - should_unpack - ) - ) - py_typecheck.check_type(parameter_type, computation_types.Type) - unpack = should_unpack # Default return value. - signature = inspect.signature(fn) - - if parameter_type is None: - if is_signature_compatible_with_types(signature): - if should_unpack: - raise ValueError('Requested unpacking of a no-arg function.') - return False - else: - raise TypeError( - 'The signature {} of the supplied function cannot be interpreted as ' - 'a body of a no-parameter computation.'.format(signature) - ) - - unpack_required = not is_signature_compatible_with_types( - signature, parameter_type - ) - if unpack_required and should_unpack is not None and not should_unpack: - raise TypeError( - "The supplied function '{}' with signature {} cannot accept a " - "value of type '{}' as a single argument.".format( - fn.__name__, signature, parameter_type - ) - ) - if is_argument_struct(parameter_type): - arg_types, kwarg_types = unpack_args_from_struct(parameter_type) - unpack_possible = is_signature_compatible_with_types( - signature, *arg_types, **kwarg_types - ) - else: - unpack_possible = False - if not unpack_possible and should_unpack is not None and should_unpack: - raise TypeError( - 'The supplied function with signature {} cannot accept a value of type ' - '{} as multiple positional and/or keyword arguments. That is, the ' - 'argument cannot be unpacked, but unpacking was requested.'.format( - signature, parameter_type - ) - ) - if unpack_required and not unpack_possible: - raise TypeError( - 'The supplied function "{}" with signature {} cannot accept a value of ' - 'type {} as either a single argument or multiple positional and/or ' - 'keyword arguments.'.format(fn.__name__, signature, parameter_type) - ) - if not unpack_required and unpack_possible and should_unpack is None: - # The supplied function could accept a value as either a single argument, - # or as multiple positional and/or keyword arguments, and the caller did - # not specify any preference, leaving ambiguity in how to handle the - # mapping. We resolve the ambiguity by defaulting to capturing the entire - # argument, as that's the behavior suggested as expected by the users. - unpack = False - - if unpack is None: - # Any ambiguity at this point has been resolved, so the following - # condition holds and need only be verified in tests. - assert unpack_required == unpack_possible, ( - unpack_required, - unpack_possible, - ) - unpack = unpack_possible - - return unpack - - -def wrap_as_zero_or_one_arg_callable( - fn: types.FunctionType, - parameter_type: Optional[computation_types.Type] = None, - unpack: Optional[bool] = None, -): - """Wraps around `fn` so it accepts up to one positional TFF-typed argument. - - This function helps to simplify dealing with functions and defuns that might - have diverse and complex signatures, but that represent computations and as - such, conceptually only accept a single parameter. The returned callable has - a single positional parameter or no parameters. If it has one parameter, the - parameter is expected to contain all arguments required by `fn` and matching - the supplied parameter type signature bundled together into a `Struct`, - if needed. The callable unpacks that structure, and passes all of - its elements as positional or keyword-based arguments in the call to `fn`. - - Example usage: - - @tf.function - def my_fn(x, y, z=10, name='bar', *p, **q): - return x + y - - type_spec = (np.int32, np.int32) - - wrapped_fn = wrap_as_zero_or_one_arg_callable(my_fn, type_spec) - - arg = Struct([('x', 10), ('y', 20)]) - - ... = wrapped_fn(arg) - - Args: - fn: The underlying backend function or defun to invoke with the unpacked - arguments. - parameter_type: The TFF type of the parameter bundle to be accepted by the - returned callable, if any, or None if there's no parameter. - unpack: Whether to break the parameter down into constituent parts and feed - them as arguments to `fn` (True), leave the parameter as is and pass it to - `fn` as a single unit (False), or allow it to be inferred from the - signature of `fn` (None). In the latter case (None), if any ambiguity - arises, an exception is thrown. If the parameter_type is None, this value - has no effect, and is simply ignored. - - Returns: - The zero- or one-argument callable that invokes `fn` with the unbundled - arguments, as described above. - - Raises: - TypeError: if arguments to this call are of the wrong types, or if the - supplied 'parameter_type' is not compatible with `fn`. - """ - # TODO: b/113112885 - Revisit whether the 3-way 'unpack' knob is sufficient - # for our needs, or more options are needed. - signature = inspect.signature(fn) - if parameter_type is None: - if is_signature_compatible_with_types(signature): - # Deliberate wrapping to isolate the caller from `fn`, e.g., to prevent - # the caller from mistakenly specifying args that match fn's defaults. - return lambda: fn() # pylint: disable=unnecessary-lambda - else: - raise TypeError( - 'The signature {} of the supplied function cannot be interpreted as ' - 'a body of a no-parameter computation.'.format(signature) - ) - else: - parameter_type = computation_types.to_type(parameter_type) - - def _call(fn, parameter_type, arg, unpack): - args, kwargs = unpack_arg(fn, parameter_type, arg, unpack) - return fn(*args, **kwargs) - - # TODO: b/132888123 - Consider other options to avoid possible bugs here. - try: - (fn, parameter_type, unpack) - except NameError as e: - raise AssertionError('Args to be bound must be in scope.') from e - return lambda arg: _call(fn, parameter_type, arg, unpack) - - -def _unpack_arg( - arg_types, kwarg_types, arg -) -> tuple[list[object], dict[str, object]]: - """Unpacks 'arg' into an argument list based on types.""" - args = [] - for idx, expected_type in enumerate(arg_types): - element_value = arg[idx] - if isinstance(element_value, structure.Struct): - element_value = type_conversions.type_to_py_container( - element_value, expected_type - ) - args.append(element_value) - kwargs = {} - for name, expected_type in kwarg_types.items(): - element_value = getattr(arg, name) - if type_analysis.is_struct_with_py_container(element_value, expected_type): - element_value = type_conversions.type_to_py_container( - element_value, expected_type - ) - kwargs[name] = element_value - return args, kwargs - - -def _ensure_arg_type( - parameter_type, arg -) -> tuple[list[object], dict[str, object]]: - """Ensures that `arg` matches `parameter_type` before returning it.""" - if type_analysis.is_struct_with_py_container(arg, parameter_type): - arg = type_conversions.type_to_py_container(arg, parameter_type) - return [arg], {} - - -def unpack_arg( - fn: types.FunctionType, - parameter_type: Optional[computation_types.Type], - arg, - unpack: Optional[bool] = None, -) -> tuple[list[object], dict[str, object]]: - """Converts TFF values into arguments to `fn`. - - Args: - fn: The function to unpack arguments for. - parameter_type: The TFF type of the parameter bundle to be accepted by the - returned callable. - arg: The argument to unpack. - unpack: Whether to break the parameter down into constituent parts (`True`), - leave the parameter as a single unit (False), or allow it to be inferred - from the signature of `fn` (None). In the latter case (None), if any - ambiguity arises, an exception is thrown. - - Returns: - The unpacked arg. - """ - if parameter_type is None: - return [], {} - - if _infer_unpack_needed(fn, parameter_type, unpack): - arg_types, kwarg_types = unpack_args_from_struct(parameter_type) - return _unpack_arg(arg_types, kwarg_types, arg) - else: - return _ensure_arg_type(parameter_type, arg) diff --git a/tensorflow_federated/python/core/impl/computation/function_utils_test.py b/tensorflow_federated/python/core/impl/computation/function_utils_test.py deleted file mode 100644 index 1f3f1756f8..0000000000 --- a/tensorflow_federated/python/core/impl/computation/function_utils_test.py +++ /dev/null @@ -1,342 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import inspect - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.types import computation_types - - -class FunctionUtilsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'args_only', - inspect.signature(lambda a: None), - [np.int32], - collections.OrderedDict(), - ), - ( - 'args_and_kwargs_unnamed', - inspect.signature(lambda a, b=True: None), - [np.int32, np.bool_], - collections.OrderedDict(), - ), - ( - 'args_and_kwargs_named', - inspect.signature(lambda a, b=True: None), - [np.int32], - collections.OrderedDict(b=np.bool_), - ), - ( - 'args_and_kwargs_default_int', - inspect.signature(lambda a=10, b=True: None), - [np.int32], - collections.OrderedDict(b=np.bool_), - ), - ) - def test_is_signature_compatible_with_types_true( - self, signature, *args, **kwargs - ): - self.assertFalse( - function_utils.is_signature_compatible_with_types( - signature, *args, **kwargs - ) - ) - - @parameterized.named_parameters( - ( - 'args_only', - inspect.signature(lambda a=True: None), - [np.int32], - collections.OrderedDict(), - ), - ( - 'args_and_kwargs', - inspect.signature(lambda a=10, b=True: None), - [np.bool_], - collections.OrderedDict(b=np.bool_), - ), - ) - def test_is_signature_compatible_with_types_false( - self, signature, *args, **kwargs - ): - self.assertFalse( - function_utils.is_signature_compatible_with_types( - signature, *args, **kwargs - ) - ) - - @parameterized.named_parameters( - ('int', np.int32, False), - ('tuple_unnamed', [np.int32, np.int32], True), - ('tuple_partially_named', [np.int32, ('b', np.int32)], True), - ('tuple_named', [('a', np.int32), ('b', np.int32)], True), - ( - 'tuple_partially_named_kwargs_first', - [('a', np.int32), np.int32], - False, - ), - ('struct', structure.Struct([(None, 1), ('a', 2)]), True), - ('struct_kwargs_first', structure.Struct([('a', 1), (None, 2)]), False), - ) - def test_is_argument_struct(self, arg, expected_result): - self.assertEqual(function_utils.is_argument_struct(arg), expected_result) - - @parameterized.named_parameters( - ('tuple_unnamed', structure.Struct([(None, 1)]), [1], {}), - ( - 'tuple_partially_named', - structure.Struct([(None, 1), ('a', 2)]), - [1], - {'a': 2}, - ), - ) - def test_unpack_args_from_structure( - self, tuple_with_args, expected_args, expected_kwargs - ): - self.assertEqual( - function_utils.unpack_args_from_struct(tuple_with_args), - (expected_args, expected_kwargs), - ) - - @parameterized.named_parameters( - ( - 'tuple_unnamed_1', - [np.int32], - [computation_types.TensorType(np.int32)], - {}, - ), - ( - 'tuple_named_1', - [('a', np.int32)], - [], - {'a': computation_types.TensorType(np.int32)}, - ), - ( - 'tuple_unnamed_2', - [np.int32, np.bool_], - [ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.bool_), - ], - {}, - ), - ( - 'tuple_partially_named', - [np.int32, ('b', np.bool_)], - [computation_types.TensorType(np.int32)], - {'b': computation_types.TensorType(np.bool_)}, - ), - ( - 'tuple_named_2', - [('a', np.int32), ('b', np.bool_)], - [], - { - 'a': computation_types.TensorType(np.int32), - 'b': computation_types.TensorType(np.bool_), - }, - ), - ) - def test_unpack_args_from_struct_type( - self, tuple_with_args, expected_args, expected_kwargs - ): - args, kwargs = function_utils.unpack_args_from_struct(tuple_with_args) - self.assertEqual(len(args), len(expected_args)) - for idx, arg in enumerate(args): - self.assertTrue(arg.is_equivalent_to(expected_args[idx])) - self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys())) - for k, v in kwargs.items(): - self.assertTrue(v.is_equivalent_to(expected_kwargs[k])) - - def test_pack_args_into_struct_without_type_spec(self): - self.assertEqual( - function_utils.pack_args_into_struct([1], {'a': 10}), - structure.Struct([(None, 1), ('a', 10)]), - ) - self.assertIn( - function_utils.pack_args_into_struct([1, 2], {'a': 10, 'b': 20}), - [ - structure.Struct([ - (None, 1), - (None, 2), - ('a', 10), - ('b', 20), - ]), - structure.Struct([ - (None, 1), - (None, 2), - ('b', 20), - ('a', 10), - ]), - ], - ) - self.assertIn( - function_utils.pack_args_into_struct([], {'a': 10, 'b': 20}), - [ - structure.Struct([('a', 10), ('b', 20)]), - structure.Struct([('b', 20), ('a', 10)]), - ], - ) - self.assertEqual( - function_utils.pack_args_into_struct([1], {}), - structure.Struct([(None, 1)]), - ) - - @parameterized.named_parameters( - ('int', [1], {}, [np.int32], [(None, 1)]), - ( - 'tuple_unnamed_with_args', - [1, True], - {}, - [np.int32, np.bool_], - [(None, 1), (None, True)], - ), - ( - 'tuple_named_with_args', - [1, True], - {}, - [('x', np.int32), ('y', np.bool_)], - [('x', 1), ('y', True)], - ), - ( - 'tuple_named_with_args_and_kwargs', - [1], - {'y': True}, - [('x', np.int32), ('y', np.bool_)], - [('x', 1), ('y', True)], - ), - ( - 'tuple_with_kwargs', - [], - {'x': 1, 'y': True}, - [('x', np.int32), ('y', np.bool_)], - [('x', 1), ('y', True)], - ), - ( - 'tuple_with_args_odict', - [], - collections.OrderedDict([('y', True), ('x', 1)]), - [('x', np.int32), ('y', np.bool_)], - [('x', 1), ('y', True)], - ), - ) - def test_pack_args_into_struct_with_type_spec_expect_success( - self, args, kwargs, type_spec, elements - ): - self.assertEqual( - function_utils.pack_args_into_struct(args, kwargs, type_spec), - structure.Struct(elements), - ) - - def test_pack_args_into_struct_named_to_unnamed_fails(self): - with self.assertRaises(TypeError): - function_utils.pack_args_into_struct( - [], {'x': 1, 'y': True}, [np.int32, np.bool_] - ) - - @parameterized.named_parameters( - ('none', None, [], {}, 'None'), - ('int', np.int32, [1], {}, '1'), - ('tuple_unnamed', [np.int32, np.bool_], [1, True], {}, '<1,True>'), - ( - 'tuple_named_with_args', - [('x', np.int32), ('y', np.bool_)], - [1, True], - {}, - '', - ), - ( - 'tuple_named_with_kwargs', - [('x', np.int32), ('y', np.bool_)], - [1], - {'y': True}, - '', - ), - ( - 'tuple_with_args_struct', - [np.int32, np.bool_], - [structure.Struct([(None, 1), (None, True)])], - {}, - '<1,True>', - ), - ) - def test_pack_args(self, parameter_type, args, kwargs, expected_value_string): - self.assertEqual( - str(function_utils.pack_args(parameter_type, args, kwargs)), - expected_value_string, - ) - - @parameterized.named_parameters( - ('const', lambda: 10, None, None, None, 10), - ('add_const', lambda x=1: x + 10, None, None, None, 11), - ( - 'add_const_with_type', - lambda x=1: x + 10, - computation_types.TensorType(np.int32), - None, - 20, - 30, - ), - ( - 'add', - lambda x, y: x + y, - computation_types.StructType([np.int32, np.int32]), - None, - structure.Struct([('x', 5), ('y', 6)]), - 11, - ), - ( - 'str_tuple', - lambda *args: str(args), - computation_types.StructType([np.int32, np.int32]), - True, - structure.Struct([(None, 5), (None, 6)]), - '(5, 6)', - ), - ( - 'str_tuple_with_named_type', - lambda *args: str(args), - computation_types.StructType([('x', np.int32), ('y', np.int32)]), - False, - structure.Struct([('x', 5), ('y', 6)]), - "(Struct([('x', 5), ('y', 6)]),)", - ), - ( - 'str_ing', - lambda x: str(x), # pylint: disable=unnecessary-lambda - computation_types.StructWithPythonType([np.int32], list), - None, - structure.Struct([(None, 10)]), - '[10]', - ), - ) - def test_wrap_as_zero_or_one_arg_callable( - self, fn, parameter_type, unpack, arg, expected_result - ): - wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable( - fn, parameter_type, unpack - ) - actual_result = wrapped_fn(arg) if parameter_type else wrapped_fn() - self.assertEqual(actual_result, expected_result) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/computation/polymorphic_computation.py b/tensorflow_federated/python/core/impl/computation/polymorphic_computation.py deleted file mode 100644 index 1b70a79944..0000000000 --- a/tensorflow_federated/python/core/impl/computation/polymorphic_computation.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for Python functions, defuns, and other types of callables.""" - -from collections.abc import Callable -from typing import Optional - -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.types import computation_types - - -class PolymorphicComputation: - """A generic polymorphic function that accepts arguments of diverse types.""" - - def __init__( - self, - concrete_function_factory: Callable[ - [computation_types.Type, Optional[bool]], - computation_impl.ConcreteComputation, - ], - infer_type_fn: Callable[[object], computation_types.Type], - ): - """Crates a polymorphic function with a given function factory. - - Args: - concrete_function_factory: A callable that accepts a (non-None) TFF type - as an argument, as well as an optional boolean `unpack` argument which - should be treated as documented in - `function_utils.wrap_as_zero_or_one_arg_callable`. The callable must - return a `Computation` instance that's been created to accept a single - positional argument of this TFF type (to be reused for future calls with - parameters of a matching type). - infer_type_fn: A `Callable` used to convert a backend-specific value to a - `tff.Type`. - """ - self._concrete_function_factory = concrete_function_factory - self._infer_type_fn = infer_type_fn - self._concrete_function_cache = {} - - def fn_for_argument_type( - self, arg_type: computation_types.Type, unpack: Optional[bool] = None - ) -> computation_impl.ConcreteComputation: - """Concretizes this function with the provided `arg_type`. - - The first time this function is called with a particular type on a - given `PolymorphicComputation` (or this `PolymorphicComputation` is called - with an argument of the given type), the underlying function will be - traced using the provided argument type as input. Later calls will - return the cached computed concrete function. - - Args: - arg_type: The argument type to use when concretizing this function. - unpack: Whether to force unpacking the arguments (`True`), never unpack - the arguments (`False`), or infer whether or not to unpack the arguments - (`None`). - - Returns: - The `tff.framework.ConcreteComputation` that results from tracing this - `PolymorphicComputation` with `arg_type. - """ - key = repr(arg_type) + str(unpack) - concrete_fn = self._concrete_function_cache.get(key) - if not concrete_fn: - concrete_fn = (self._concrete_function_factory)(arg_type, unpack) - if concrete_fn.type_signature.parameter != arg_type: - raise TypeError( - 'Expected a concrete function that takes parameter {}, got one ' - 'that takes {}.'.format( - arg_type, concrete_fn.type_signature.parameter - ) - ) - self._concrete_function_cache[key] = concrete_fn - return concrete_fn - - def __call__(self, *args, **kwargs): - """Invokes this polymorphic function with a given set of arguments. - - Args: - *args: Positional args. - **kwargs: Keyword args. - - Returns: - The result of calling a concrete function, instantiated on demand based - on the argument types (and cached for future calls). - - Raises: - TypeError: if the concrete functions created by the factory are of the - wrong computation_types. - """ - packed_arg = function_utils.pack_args_into_struct(args, kwargs) - args_type = self._infer_type_fn(args) - if not isinstance(args_type, computation_types.StructType): - raise ValueError - kwargs_type = self._infer_type_fn(kwargs) - if not isinstance(kwargs_type, computation_types.StructType): - raise ValueError - arg_type = computation_types.StructType([ - *args_type.items(), - *kwargs_type.items(), - ]) - # We know the argument types have been packed, so force unpacking. - concrete_fn = self.fn_for_argument_type(arg_type, unpack=True) - return concrete_fn(packed_arg) diff --git a/tensorflow_federated/python/core/impl/computation/polymorphic_computation_test.py b/tensorflow_federated/python/core/impl/computation/polymorphic_computation_test.py deleted file mode 100644 index 84fd7239b3..0000000000 --- a/tensorflow_federated/python/core/impl/computation/polymorphic_computation_test.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import polymorphic_computation -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import type_serialization - - -class PolymorphicComputationTest(absltest.TestCase): - - def test_call_returns_result(self): - class TestContext(context_base.SyncContext): - - def ingest(self, val, type_spec): - del type_spec # Unused. - return val - - def invoke(self, comp, arg): - return 'name={},type={},arg={},unpack={}'.format( - comp.name, comp.type_signature.parameter, arg, comp.unpack - ) - - class TestContextStack(context_stack_base.ContextStack): - - def __init__(self): - super().__init__() - self._context = TestContext() - - @property - def current(self): - return self._context - - def install(self, ctx): - del ctx # Unused - return self._context - - context_stack = TestContextStack() - - class TestFunction(computation_impl.ConcreteComputation): - - def __init__(self, name, unpack, parameter_type): - self._name = name - self._unpack = unpack - type_signature = computation_types.FunctionType(parameter_type, np.str_) - test_proto = pb.Computation( - type=type_serialization.serialize_type(type_signature) - ) - super().__init__( - computation_proto=test_proto, - context_stack=context_stack, - annotated_type=type_signature, - ) - - @property - def name(self): - return self._name - - @property - def unpack(self): - return self._unpack - - class TestFunctionFactory: - - def __init__(self): - self._count = 0 - - def __call__(self, parameter_type, unpack): - self._count = self._count + 1 - return TestFunction(str(self._count), str(unpack), parameter_type) - - fn = polymorphic_computation.PolymorphicComputation( - TestFunctionFactory(), type_conversions.infer_type - ) - - self.assertEqual(fn(10), 'name=1,type=,arg=<10>,unpack=True') - self.assertEqual( - fn(20, x=True), 'name=2,type=,arg=<20,x=True>,unpack=True' - ) - fn_with_bool_arg = fn.fn_for_argument_type( - computation_types.TensorType(np.bool_) - ) - self.assertEqual( - fn_with_bool_arg(True), 'name=3,type=bool,arg=True,unpack=None' - ) - self.assertEqual( - fn(30, x=40), 'name=4,type=,arg=<30,x=40>,unpack=True' - ) - self.assertEqual(fn(50), 'name=1,type=,arg=<50>,unpack=True') - self.assertEqual( - fn(0, x=False), 'name=2,type=,arg=<0,x=False>,unpack=True' - ) - fn_with_bool_arg = fn.fn_for_argument_type( - computation_types.TensorType(np.bool_) - ) - self.assertEqual( - fn_with_bool_arg(False), 'name=3,type=bool,arg=False,unpack=None' - ) - self.assertEqual( - fn(60, x=70), 'name=4,type=,arg=<60,x=70>,unpack=True' - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/context_stack/BUILD b/tensorflow_federated/python/core/impl/context_stack/BUILD deleted file mode 100644 index 8ced9f4a1c..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -load("@rules_python//python:defs.bzl", "py_library", "py_test") - -package( - default_applicable_licenses = ["//:package_license"], - default_visibility = [ - ":context_stack_packages", - "//tensorflow_federated/python/core/impl:impl_users", - "//tensorflow_federated/python/core/impl/computation:computation_packages", - "//tensorflow_federated/python/core/impl/execution_contexts:execution_contexts_packages", - "//tensorflow_federated/python/core/impl/executors:executors_packages", - "//tensorflow_federated/python/core/impl/federated_context:federated_context_packages", - ], -) - -package_group( - name = "context_stack_packages", - packages = ["//tensorflow_federated/python/core/impl/context_stack/..."], -) - -licenses(["notice"]) - -py_library( - name = "context_stack", - srcs = ["__init__.py"], - visibility = ["//tools/python_package:python_package_tool"], -) - -py_library( - name = "context_base", - srcs = ["context_base.py"], -) - -py_library( - name = "context_stack_base", - srcs = ["context_stack_base.py"], -) - -py_library( - name = "context_stack_impl", - srcs = ["context_stack_impl.py"], - deps = [ - ":context_base", - ":context_stack_base", - ":runtime_error_context", - "//tensorflow_federated/python/common_libs:py_typecheck", - ], -) - -py_test( - name = "context_stack_impl_test", - size = "small", - srcs = ["context_stack_impl_test.py"], - deps = [ - ":context_stack_impl", - ":context_stack_test_utils", - ], -) - -py_library( - name = "context_stack_test_utils", - srcs = ["context_stack_test_utils.py"], - deps = [ - ":context_base", - ":context_stack_impl", - ], -) - -py_test( - name = "context_stack_test_utils_test", - srcs = ["context_stack_test_utils_test.py"], - deps = [ - ":context_stack_impl", - ":context_stack_test_utils", - ], -) - -py_library( - name = "get_context_stack", - srcs = ["get_context_stack.py"], - deps = [":context_stack_impl"], -) - -py_test( - name = "get_context_stack_test", - size = "small", - srcs = ["get_context_stack_test.py"], - deps = [ - ":context_stack_impl", - ":get_context_stack", - ], -) - -py_library( - name = "runtime_error_context", - srcs = ["runtime_error_context.py"], - deps = [":context_base"], -) - -py_library( - name = "set_default_context", - srcs = ["set_default_context.py"], - deps = [ - ":context_stack_impl", - ":runtime_error_context", - ], -) - -py_test( - name = "set_default_context_test", - size = "small", - srcs = ["set_default_context_test.py"], - deps = [ - ":context_stack_impl", - ":context_stack_test_utils", - ":set_default_context", - ], -) - -py_library( - name = "symbol_binding_context", - srcs = ["symbol_binding_context.py"], - deps = [":context_base"], -) diff --git a/tensorflow_federated/python/core/impl/context_stack/__init__.py b/tensorflow_federated/python/core/impl/context_stack/__init__.py deleted file mode 100644 index 111d2374c6..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Libraries for interacting with the context of a computation.""" diff --git a/tensorflow_federated/python/core/impl/context_stack/context_base.py b/tensorflow_federated/python/core/impl/context_stack/context_base.py deleted file mode 100644 index 14aa09c848..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/context_base.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines context interfaces which evaluates computation invocations. - -Invocations of TensorFlow Federated computations need to be treated differently -depending on the context in which they are invoked. For example: - -* During top-level Python simulations, computation invocations result in the - computation being serialized and evaluated by the TensorFlow native runtime. - -* In functions decorated with `@tff.tensorflow.computation`, computation - invocations must import the body of the invoked function into the current - TensorFlow graph. - -Code can customize the way in which each of these calls are evaluated by setting -a specific context using a global or thread-local context stack. -""" - -import abc -from typing import Any - - -class ContextError(RuntimeError): - pass - - -class SyncContext(metaclass=abc.ABCMeta): - """A synchronous context to evaluate of computations.""" - - @abc.abstractmethod - def invoke(self, comp: Any, arg: Any) -> Any: - """Invokes computation `comp` with argument `arg`. - - Args: - comp: The computation being invoked. The Python type of `comp` expected - here (e.g., `pb.Computation`. `tff.framework.ConcreteComputation`, or - other) may depend on the context. It is the responsibility of the - concrete implementation of this interface to verify that the type of - `comp` matches what the context is expecting. - arg: The argument passed to the computation. If no argument is passed, - this will be `None`. Structural argument types will be normalized into - `tff.structure.Struct`s. - - Returns: - The result of invocation, which is context-dependent. - """ - raise NotImplementedError - - -class AsyncContext(metaclass=abc.ABCMeta): - """An asynchronous context to evaluate of computations.""" - - @abc.abstractmethod - async def invoke(self, comp: Any, arg: Any) -> Any: - """Invokes computation `comp` with argument `arg`. - - Args: - comp: The computation being invoked. The Python type of `comp` expected - here (e.g., `pb.Computation`. `tff.framework.ConcreteComputation`, or - other) may depend on the context. It is the responsibility of the - concrete implementation of this interface to verify that the type of - `comp` matches what the context is expecting. - arg: The argument passed to the computation. If no argument is passed, - this will be `None`. Structural argument types will be normalized into - `tff.structure.Struct`s. - - Returns: - The result of invocation, which is context-dependent. - """ - raise NotImplementedError diff --git a/tensorflow_federated/python/core/impl/context_stack/context_stack_base.py b/tensorflow_federated/python/core/impl/context_stack/context_stack_base.py deleted file mode 100644 index 348f934fff..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/context_stack_base.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines the interface for the context stack.""" - -import abc -import contextlib - - -class ContextStack(metaclass=abc.ABCMeta): - """An interface to a context stack for the API to run against.""" - - @property - @abc.abstractmethod - def current(self): - """Returns the current context (one at the top of the context stack).""" - raise NotImplementedError - - @contextlib.contextmanager - @abc.abstractmethod - def install(self, ctx): - """A context manager that temporarily installs a new context on the stack. - - The installed context is placed at the top on the stack while in the context - manager's scope, and remove from the stack upon exiting the scope. This - method should only be used by the implementation code, and by the unit tests - for dependency injection. - - Args: - ctx: The context to temporarily install at the top of the context stack, - an instance of `Context` defined in `context_base.py`. - - Yields: - The installed context. - - Raises: - TypeError: If `ctx` is not a valid instance of - `tff.framework.AsyncContext` or `tff.framework.SyncContext`. - """ - raise NotImplementedError diff --git a/tensorflow_federated/python/core/impl/context_stack/context_stack_impl.py b/tensorflow_federated/python/core/impl/context_stack/context_stack_impl.py deleted file mode 100644 index cd3f38efed..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/context_stack_impl.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines classes/functions to manipulate the API context stack.""" - -from collections.abc import Generator -import contextlib -import threading -import typing -from typing import Union - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.context_stack import runtime_error_context - - -_Context = Union[context_base.SyncContext, context_base.AsyncContext] - - -class ContextStackImpl(context_stack_base.ContextStack, threading.local): - """An implementation of a common thread-local context stack to run against.""" - - def __init__(self, default_context): - super().__init__() - self._stack = [default_context] - - def set_default_context(self, ctx: _Context) -> None: - """Places `ctx` at the bottom of the stack. - - Args: - ctx: An instance of `tff.framework.AsyncContext` or - `tff.framework.AsyncContext`. - """ - py_typecheck.check_type(ctx, typing.get_args(_Context)) - assert self._stack - self._stack[0] = ctx - - @property - def current(self) -> _Context: - assert self._stack - ctx = self._stack[-1] - assert isinstance(ctx, typing.get_args(_Context)) - return ctx - - @contextlib.contextmanager - def install(self, ctx: _Context) -> Generator[_Context, None, None]: - py_typecheck.check_type(ctx, typing.get_args(_Context)) - self._stack.append(ctx) - try: - yield ctx - finally: - self._stack.pop() - - -context_stack = ContextStackImpl(runtime_error_context.RuntimeErrorContext()) diff --git a/tensorflow_federated/python/core/impl/context_stack/context_stack_impl_test.py b/tensorflow_federated/python/core/impl/context_stack/context_stack_impl_test.py deleted file mode 100644 index 0d577b5ade..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/context_stack_impl_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils - - -class ContextStackTest(absltest.TestCase): - - def test_set_default_context_with_context(self): - default_context = context_stack_test_utils.TestContext() - context_stack = context_stack_impl.ContextStackImpl(default_context) - context = context_stack_test_utils.TestContext() - self.assertIsNot(context_stack.current, context) - - context_stack.set_default_context(context) - - self.assertIs(context_stack.current, context) - - def test_set_default_context_raises_type_error_with_none(self): - default_context = context_stack_test_utils.TestContext() - context_stack = context_stack_impl.ContextStackImpl(default_context) - - with self.assertRaises(TypeError): - context_stack.set_default_context(None) - - def test_install_pushes_context_on_stack(self): - default_context = context_stack_test_utils.TestContext() - context_stack = context_stack_impl.ContextStackImpl(default_context) - self.assertIs(context_stack.current, default_context) - - context_two = context_stack_test_utils.TestContext() - with context_stack.install(context_two): - self.assertIs(context_stack.current, context_two) - - context_three = context_stack_test_utils.TestContext() - with context_stack.install(context_three): - self.assertIs(context_stack.current, context_three) - - self.assertIs(context_stack.current, context_two) - - self.assertIs(context_stack.current, default_context) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/context_stack/context_stack_test_utils.py b/tensorflow_federated/python/core/impl/context_stack/context_stack_test_utils.py deleted file mode 100644 index 0b46d3043d..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/context_stack_test_utils.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for testing context stacks.""" - -import asyncio -from collections.abc import Callable, Iterable -import contextlib -import functools -from typing import Optional, Union - -from absl.testing import parameterized - -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl - -_Context = Union[context_base.AsyncContext, context_base.SyncContext] -_ContextFactory = Callable[[], _Context] -_EnvironmentFactory = Callable[ - [], Iterable[contextlib.AbstractContextManager[None]] -] - - -class TestContext(context_base.SyncContext): - """A test context.""" - - def invoke(self, comp, arg): - return NotImplementedError - - -@contextlib.contextmanager -def test_environment(): - yield None - - -def with_context( - context_fn: _ContextFactory, - environment_fn: Optional[_EnvironmentFactory] = None, -): - """Returns a decorator for running a test in a context. - - Args: - context_fn: A `Callable` that constructs a `tff.framework.AsyncContext` or - `tff.framework.SyncContext` to install beore invoking the decorated - function. - environment_fn: A `Callable` that constructs a list of - `contextlib.AbstractContextManager` to enter before invoking the decorated - function. - """ - - def decorator(fn): - - @contextlib.contextmanager - def install_context( - context_fn: _ContextFactory, - environment_fn: Optional[_EnvironmentFactory] = None, - ): - context = context_fn() - with context_stack_impl.context_stack.install(context): - if environment_fn is not None: - with contextlib.ExitStack() as stack: - context_managers = environment_fn() - for context_manager in context_managers: - stack.enter_context(context_manager) - yield - else: - yield - - if asyncio.iscoroutinefunction(fn): - - @functools.wraps(fn) - async def wrapper(*args, **kwargs): - with install_context(context_fn, environment_fn): - return await fn(*args, **kwargs) - - else: - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - with install_context(context_fn, environment_fn): - return fn(*args, **kwargs) - - return wrapper - - return decorator - - -def with_contexts(*named_contexts): - """Returns a decorator for parameterizing a test by a context. - - Args: - *named_contexts: Named parameters used to construct the `with_context` - decorator; either a single iterable, or a list of `tuple`s or `dict`s. - - Raises: - ValueError: If no named contexts are passed to the decorator. - """ - if not named_contexts: - raise ValueError('Expected at least one named parameter, found none.') - - def decorator(fn): - - if asyncio.iscoroutinefunction(fn): - - @functools.wraps(fn) - @parameterized.named_parameters(*named_contexts) - async def wrapper( - self, - context_fn: _ContextFactory, - environment_fn: Optional[_EnvironmentFactory] = None, - ): - decorator = with_context(context_fn, environment_fn) - decorated_fn = decorator(fn) - await decorated_fn(self) - - else: - - @functools.wraps(fn) - @parameterized.named_parameters(*named_contexts) - def wrapper( - self, - context_fn: _ContextFactory, - environment_fn: Optional[_EnvironmentFactory] = None, - ): - decorator = with_context(context_fn, environment_fn) - decorated_fn = decorator(fn) - decorated_fn(self) - - return wrapper - - return decorator diff --git a/tensorflow_federated/python/core/impl/context_stack/context_stack_test_utils_test.py b/tensorflow_federated/python/core/impl/context_stack/context_stack_test_utils_test.py deleted file mode 100644 index e82adffe71..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/context_stack_test_utils_test.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2022, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import contextlib -import unittest -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized - -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils - - -class WithContextTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): - - @parameterized.named_parameters( - ( - 'context_and_no_environment', - context_stack_test_utils.TestContext(), - None, - ), - ( - 'context_and_empty_environment', - context_stack_test_utils.TestContext(), - [], - ), - ( - 'context_and_1_environment', - context_stack_test_utils.TestContext(), - [context_stack_test_utils.test_environment()], - ), - ( - 'context_and_3_environment', - context_stack_test_utils.TestContext(), - [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ], - ), - ) - def test_installs_context_fn_sync_no_arg(self, context, environments): - context_fn = lambda: context - environment_fn = None if environments is None else lambda: environments - - @context_stack_test_utils.with_context(context_fn, environment_fn) - def _foo(): - self.assertEqual(context_stack_impl.context_stack.current, context) - - # Assert that a sync function is returned. - self.assertFalse(asyncio.iscoroutinefunction(_foo)) - - with mock.patch.object( - contextlib.ExitStack, 'enter_context' - ) as mock_enter_context: - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - _foo() - - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that `enter_context` is called with the expected environment. - if environments is not None: - calls = [mock.call(e) for e in environments] - self.assertEqual(mock_enter_context.mock_calls, calls) - - @parameterized.named_parameters( - ( - 'context_and_no_environment', - context_stack_test_utils.TestContext(), - None, - ), - ( - 'context_and_empty_environment', - context_stack_test_utils.TestContext(), - [], - ), - ( - 'context_and_1_environment', - context_stack_test_utils.TestContext(), - [context_stack_test_utils.test_environment()], - ), - ( - 'context_and_3_environment', - context_stack_test_utils.TestContext(), - [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ], - ), - ) - - def test_installs_context_fn_sync_args(self, context, environments): - context_fn = lambda: context - environment_fn = None if environments is None else lambda: environments - - @context_stack_test_utils.with_context(context_fn, environment_fn) - def _foo(x): - del x # Unused. - self.assertEqual(context_stack_impl.context_stack.current, context) - - # Assert that a sync function is returned. - self.assertFalse(asyncio.iscoroutinefunction(_foo)) - - with mock.patch.object( - contextlib.ExitStack, 'enter_context' - ) as mock_enter_context: - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - _foo(1) - - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that `enter_context` is called with the expected environment. - if environments is not None: - calls = [mock.call(e) for e in environments] - self.assertEqual(mock_enter_context.mock_calls, calls) - - @parameterized.named_parameters( - ( - 'context_and_no_environment', - context_stack_test_utils.TestContext(), - None, - ), - ( - 'context_and_empty_environment', - context_stack_test_utils.TestContext(), - [], - ), - ( - 'context_and_1_environment', - context_stack_test_utils.TestContext(), - [context_stack_test_utils.test_environment()], - ), - ( - 'context_and_3_environment', - context_stack_test_utils.TestContext(), - [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ], - ), - ) - - def test_installs_context_fn_sync_kwargs(self, context, environments): - context_fn = lambda: context - environment_fn = None if environments is None else lambda: environments - - @context_stack_test_utils.with_context(context_fn, environment_fn) - def _foo(*, x): - del x # Unused. - self.assertEqual(context_stack_impl.context_stack.current, context) - - # Assert that a sync function is returned. - self.assertFalse(asyncio.iscoroutinefunction(_foo)) - - with mock.patch.object( - contextlib.ExitStack, 'enter_context' - ) as mock_enter_context: - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - _foo(x=1) - - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that `enter_context` is called with the expected environment. - if environments is not None: - calls = [mock.call(e) for e in environments] - self.assertEqual(mock_enter_context.mock_calls, calls) - - @parameterized.named_parameters( - ( - 'context_and_no_environment', - context_stack_test_utils.TestContext(), - None, - ), - ( - 'context_and_empty_environment', - context_stack_test_utils.TestContext(), - [], - ), - ( - 'context_and_1_environment', - context_stack_test_utils.TestContext(), - [context_stack_test_utils.test_environment()], - ), - ( - 'context_and_3_environment', - context_stack_test_utils.TestContext(), - [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ], - ), - ) - - def test_installs_context_fn_sync_return(self, context, environments): - context_fn = lambda: context - environment_fn = None if environments is None else lambda: environments - - @context_stack_test_utils.with_context(context_fn, environment_fn) - def _foo(): - self.assertEqual(context_stack_impl.context_stack.current, context) - return 1 - - # Assert that a sync function is returned. - self.assertFalse(asyncio.iscoroutinefunction(_foo)) - - with mock.patch.object( - contextlib.ExitStack, 'enter_context' - ) as mock_enter_context: - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - x = _foo() - - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that `enter_context` is called with the expected environment. - if environments is not None: - calls = [mock.call(e) for e in environments] - self.assertEqual(mock_enter_context.mock_calls, calls) - - # Assert that the return value is returned by the decorator. - self.assertEqual(x, 1) - - @parameterized.named_parameters( - ( - 'context_and_no_environment', - context_stack_test_utils.TestContext(), - None, - ), - ( - 'context_and_empty_environment', - context_stack_test_utils.TestContext(), - [], - ), - ( - 'context_and_1_environment', - context_stack_test_utils.TestContext(), - [context_stack_test_utils.test_environment()], - ), - ( - 'context_and_3_environment', - context_stack_test_utils.TestContext(), - [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ], - ), - ) - - async def test_installs_context_fn_async(self, context, environments): - context_fn = lambda: context - environment_fn = None if environments is None else lambda: environments - - @context_stack_test_utils.with_context(context_fn, environment_fn) - async def _foo(): - self.assertEqual(context_stack_impl.context_stack.current, context) - - # Assert that an async function is returned. - self.assertTrue(asyncio.iscoroutinefunction(_foo)) - - with mock.patch.object( - contextlib.ExitStack, 'enter_context' - ) as mock_enter_context: - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - await _foo() - - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that `enter_context` is called with the expected environment. - if environments is not None: - calls = [mock.call(e) for e in environments] - self.assertEqual(mock_enter_context.mock_calls, calls) - - @parameterized.named_parameters( - ( - 'context_and_no_environment', - context_stack_test_utils.TestContext(), - None, - ), - ( - 'context_and_empty_environment', - context_stack_test_utils.TestContext(), - [], - ), - ( - 'context_and_1_environment', - context_stack_test_utils.TestContext(), - [context_stack_test_utils.test_environment()], - ), - ( - 'context_and_3_environments', - context_stack_test_utils.TestContext(), - [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ], - ), - ) - - def test_installs_context_test_case(self, context, environments): - context_fn = lambda: context - environment_fn = None if environments is None else lambda: environments - - class _FooTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): - - @context_stack_test_utils.with_context(context_fn, environment_fn) - async def test_async(self): - self.assertEqual(context_stack_impl.context_stack.current, context) - - @context_stack_test_utils.with_context(context_fn, environment_fn) - def test_sync(self): - self.assertEqual(context_stack_impl.context_stack.current, context) - - def test_undecorated(self): - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that a sync function is returned. - self.assertFalse(asyncio.iscoroutinefunction(_FooTest.test_sync)) - - # Assert that an async function is returned. - self.assertTrue(asyncio.iscoroutinefunction(_FooTest.test_async)) - - with mock.patch.object( - contextlib.ExitStack, 'enter_context' - ) as mock_enter_context: - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that the test passes with the expected number of test cases. - suite = unittest.defaultTestLoader.loadTestsFromTestCase(_FooTest) - self.assertEqual(suite.countTestCases(), 3) - runner = unittest.TextTestRunner() - result = runner.run(suite) - self.assertEqual(result.testsRun, 3) - self.assertTrue(result.wasSuccessful()) - - # Assert that the context is not installed. - self.assertNotEqual(context_stack_impl.context_stack.current, context) - - # Assert that `enter_context` is called with the expected environment. - if environments is not None: - calls = [mock.call(e) for e in environments] * 2 - self.assertEqual(mock_enter_context.mock_calls, calls) - - -class WithContextsTest(parameterized.TestCase): - - def test_installs_contexts_test_case(self): - def _context_fn(): - return context_stack_test_utils.TestContext() - - def _environment_fn(): - return [ - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - context_stack_test_utils.test_environment(), - ] - - named_contexts = [ - ('1', _context_fn, _environment_fn), - ('2', _context_fn, _environment_fn), - ('3', _context_fn, _environment_fn), - ] - - class _FooTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): - - @context_stack_test_utils.with_contexts(*named_contexts) - async def test_async(self): - pass - - @context_stack_test_utils.with_contexts(*named_contexts) - def test_sync(self): - pass - - def test_undecorated(self): - pass - - # Assert that a sync function is returned. - for name, _, _ in named_contexts: - test_name = f'test_sync_{name}' - self.assertTrue(hasattr(_FooTest, test_name)) - test_fn = getattr(_FooTest, test_name) - self.assertFalse(asyncio.iscoroutinefunction(test_fn)) - - # Assert that an async function is returned. - for name, _, _ in named_contexts: - test_name = f'test_async_{name}' - self.assertTrue(hasattr(_FooTest, test_name)) - test_fn = getattr(_FooTest, test_name) - self.assertTrue(asyncio.iscoroutinefunction(test_fn)) - - async_values = [lambda _: mock.AsyncMock()] * 3 - sync_values = [lambda _: mock.MagicMock()] * 3 - with mock.patch.object( - context_stack_test_utils, - 'with_context', - side_effect=async_values + sync_values, - ) as mock_with_context: - # Assert that the test passes with the expected number of test cases. - suite = unittest.defaultTestLoader.loadTestsFromTestCase(_FooTest) - self.assertEqual(suite.countTestCases(), len(named_contexts) * 2 + 1) - runner = unittest.TextTestRunner() - result = runner.run(suite) - self.assertEqual(result.testsRun, len(named_contexts) * 2 + 1) - self.assertTrue(result.wasSuccessful()) - - # Assert that `with_context` is called with the expected parameters. - calls = [mock.call(a, b) for _, a, b in named_contexts] * 2 - self.assertEqual(mock_with_context.mock_calls, calls) - - def test_raises_value_error(self): - with self.assertRaises(ValueError): - context_stack_test_utils.with_contexts() - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/context_stack/get_context_stack.py b/tensorflow_federated/python/core/impl/context_stack/get_context_stack.py deleted file mode 100644 index 576fdc2837..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/get_context_stack.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A utility to get the context stack.""" - -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl - - -def get_context_stack(): - """Returns the context stack.""" - return context_stack_impl.context_stack diff --git a/tensorflow_federated/python/core/impl/context_stack/get_context_stack_test.py b/tensorflow_federated/python/core/impl/context_stack/get_context_stack_test.py deleted file mode 100644 index 61f4c11d77..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/get_context_stack_test.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import get_context_stack - - -class GetContextStackTest(absltest.TestCase): - - def test_returns_context(self): - context_stack = get_context_stack.get_context_stack() - self.assertIsInstance(context_stack, context_stack_impl.ContextStackImpl) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/context_stack/runtime_error_context.py b/tensorflow_federated/python/core/impl/context_stack/runtime_error_context.py deleted file mode 100644 index 78edd160ad..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/runtime_error_context.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines classes/functions to manipulate the API context stack.""" - -from tensorflow_federated.python.core.impl.context_stack import context_base - - -class RuntimeErrorContext(context_base.SyncContext): - """A context that will fail if you execute against it.""" - - def _raise_runtime_error(self): - raise RuntimeError( - 'No default context installed.\n' - '\n' - 'You should not expect to get this error using the TFF API.\n' - '\n' - 'If you are getting this error when testing a module inside of ' - '`tensorflow_federated/python/core/...`, you may need to explicitly ' - 'invoke `execution_contexts.set_sync_local_cpp_execution_context()` in ' - 'the `main` function of your test.' - ) - - def invoke(self, comp, arg): - del comp # Unused - del arg # Unused - self._raise_runtime_error() - - -def create_runtime_error_context() -> RuntimeErrorContext: - """Creates a context that will raise an error when computations are invoked.""" - return RuntimeErrorContext() diff --git a/tensorflow_federated/python/core/impl/context_stack/set_default_context.py b/tensorflow_federated/python/core/impl/context_stack/set_default_context.py deleted file mode 100644 index b368cb55c9..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/set_default_context.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A utility to change the context stack.""" - -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import runtime_error_context - - -def set_default_context(ctx): - """Places `ctx` at the bottom of the stack. - - Args: - ctx: A `tff.framework.AsyncContext` or `tff.framework.SyncContext`. - """ - context_stack_impl.context_stack.set_default_context(ctx) - - -def set_no_default_context(): - """Places a `RuntimeErrorContext` at the bottom of the stack.""" - set_default_context(runtime_error_context.RuntimeErrorContext()) diff --git a/tensorflow_federated/python/core/impl/context_stack/set_default_context_test.py b/tensorflow_federated/python/core/impl/context_stack/set_default_context_test.py deleted file mode 100644 index c0c2d380ef..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/set_default_context_test.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils -from tensorflow_federated.python.core.impl.context_stack import set_default_context - - -class SetDefaultContextTest(absltest.TestCase): - - def setUp(self): - super().setUp() - # In these tests we are setting the default context of the - # `context_stack_impl.context_stack`, so here we reset that context back to - # some known state. - self.context = context_stack_test_utils.TestContext() - context_stack_impl.context_stack.set_default_context(self.context) - - def test_with_context(self): - context = context_stack_test_utils.TestContext() - context_stack = context_stack_impl.context_stack - self.assertIsNot(context_stack.current, context) - - set_default_context.set_default_context(context) - - self.assertIs(context_stack.current, context) - - def test_raises_type_error_with_none(self): - with self.assertRaises(TypeError): - set_default_context.set_default_context(None) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/context_stack/symbol_binding_context.py b/tensorflow_federated/python/core/impl/context_stack/symbol_binding_context.py deleted file mode 100644 index c0ea24ccfd..0000000000 --- a/tensorflow_federated/python/core/impl/context_stack/symbol_binding_context.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines interface for contexts which can bind symbols.""" - -import abc -from typing import Generic, TypeVar - -from tensorflow_federated.python.core.impl.context_stack import context_base - - -_Symbol = TypeVar('_Symbol') -_Reference = TypeVar('_Reference') - - -class SymbolBindingContext( - context_base.SyncContext, abc.ABC, Generic[_Symbol, _Reference] -): - """Interface for contexts which handle binding and tracking of references.""" - - @abc.abstractmethod - def bind_computation_to_reference(self, comp: _Symbol) -> _Reference: - """Binds a computation to a symbol, returns a reference to this binding.""" - raise NotImplementedError - - @property - @abc.abstractmethod - def symbol_bindings(self) -> list[tuple[str, _Symbol]]: - """Returns all symbols bound in this context.""" - raise NotImplementedError diff --git a/tensorflow_federated/python/core/impl/execution_contexts/BUILD b/tensorflow_federated/python/core/impl/execution_contexts/BUILD index facb1d1581..e0738b73e7 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/BUILD +++ b/tensorflow_federated/python/core/impl/execution_contexts/BUILD @@ -21,76 +21,14 @@ py_library( visibility = ["//tools/python_package:python_package_tool"], ) -py_library( - name = "async_execution_context", - srcs = ["async_execution_context.py"], - deps = [ - ":compiler_pipeline", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:retrying", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/common_libs:tracing", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/executors:cardinalities_utils", - "//tensorflow_federated/python/core/impl/executors:executor_base", - "//tensorflow_federated/python/core/impl/executors:executor_factory", - "//tensorflow_federated/python/core/impl/executors:executor_value_base", - "//tensorflow_federated/python/core/impl/executors:executors_errors", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:typed_object", - ], -) - -py_test( - name = "async_execution_context_test", - size = "small", - srcs = ["async_execution_context_test.py"], - deps = [ - ":async_execution_context", - "//tensorflow_federated/python/core/impl/executors:executors_errors", - ], -) - -py_library( - name = "compiler_pipeline", - srcs = ["compiler_pipeline.py"], - deps = [ - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/computation:computation_base", - ], -) - -py_test( - name = "compiler_pipeline_test", - size = "small", - srcs = ["compiler_pipeline_test.py"], - deps = [ - ":compiler_pipeline", - "//tensorflow_federated/python/core/impl/computation:computation_base", - ], -) - py_library( name = "mergeable_comp_execution_context", srcs = ["mergeable_comp_execution_context.py"], deps = [ - ":compiler_pipeline", "//tensorflow_federated/python/common_libs:async_utils", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/executors:cardinalities_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:typed_object", + "@federated_language//federated_language", ], ) @@ -102,21 +40,6 @@ py_test( deps = [ ":mergeable_comp_execution_context", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "sync_execution_context", - srcs = ["sync_execution_context.py"], - deps = [ - ":async_execution_context", - "//tensorflow_federated/python/common_libs:async_utils", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/executors:cardinalities_utils", - "//tensorflow_federated/python/core/impl/executors:executor_factory", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/async_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/async_execution_context.py deleted file mode 100644 index 2c26bacb2d..0000000000 --- a/tensorflow_federated/python/core/impl/execution_contexts/async_execution_context.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A context for execution based on an embedded executor instance.""" - -import asyncio -from collections.abc import Callable, Mapping, Sequence -import contextlib -from typing import Generic, Optional, TypeVar - -import tree - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import retrying -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.common_libs import tracing -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.execution_contexts import compiler_pipeline -from tensorflow_federated.python.core.impl.executors import cardinalities_utils -from tensorflow_federated.python.core.impl.executors import executor_base -from tensorflow_federated.python.core.impl.executors import executor_factory -from tensorflow_federated.python.core.impl.executors import executor_value_base -from tensorflow_federated.python.core.impl.executors import executors_errors -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import typed_object - - -_Computation = TypeVar('_Computation', bound=computation_base.Computation) - - -def _is_retryable_error(exception): - return isinstance(exception, executors_errors.RetryableError) - - -class AsyncExecutionContextValue(typed_object.TypedObject): - """Wrapper class for values produced by `ExecutionContext`.""" - - def __init__(self, value, type_spec): - py_typecheck.check_type(type_spec, computation_types.Type) - self._value = value - self._type_spec = type_spec - - @property - def type_signature(self): - return self._type_spec - - @property - def value(self): - return self._value - - -async def _ingest(executor, val, type_spec): - """A coroutine that handles ingestion. - - Args: - executor: An instance of `executor_base.Executor`. - val: The first argument to `AsyncExecutionContext.ingest()`. - type_spec: The second argument to `AsyncExecutionContext.ingest()`. - - Returns: - The result of the ingestion. - - Raises: - TypeError: If the `val` is not a value of type `type_spec`. - """ - if isinstance(val, executor_value_base.ExecutorValue): - return val - elif isinstance(val, structure.Struct) and not isinstance( - type_spec, computation_types.FederatedType - ): - if not isinstance(type_spec, computation_types.StructType): - raise ValueError(f'Expected a `tff.StructType`, found {type_spec}.') - v_elem = structure.to_elements(val) - t_elem = list(type_spec.items()) - if len(v_elem) != len(t_elem): - raise TypeError( - 'Value {} does not match type {}: mismatching tuple length.'.format( - val, type_spec - ) - ) - for (vk, _), (tk, _) in zip(v_elem, t_elem): - if vk not in [tk, None]: - raise TypeError( - 'Value {} does not match type {}: mismatching tuple element ' - 'names {} vs. {}.'.format(val, type_spec, vk, tk) - ) - ingested = [] - for (_, v), (_, t) in zip(v_elem, t_elem): - ingested.append(_ingest(executor, v, t)) - ingested = await asyncio.gather(*ingested) - return await executor.create_struct( - structure.Struct( - (name, val) for (name, _), val in zip(t_elem, ingested) - ) - ) - else: - return await executor.create_value(val, type_spec) - - -async def _invoke(executor, comp, arg, result_type: computation_types.Type): - """A coroutine that handles invocation. - - Args: - executor: An instance of `executor_base.Executor`. - comp: The first argument to `AsyncExecutionContext.invoke()`. - arg: The optional second argument to `AsyncExecutionContext.invoke()`. - result_type: The type signature of the result. This is used to convert the - execution result into the proper container types. - - Returns: - The result of the invocation. - """ - if arg is not None: - py_typecheck.check_type(arg, executor_value_base.ExecutorValue) - comp = await executor.create_value(comp, comp.type_signature) - result = await executor.create_call(comp, arg) - py_typecheck.check_type(result, executor_value_base.ExecutorValue) - result_value = await result.compute() - return type_conversions.type_to_py_container(result_value, result_type) - - -class AsyncExecutionContext(context_base.AsyncContext, Generic[_Computation]): - """An asynchronous execution context backed by an `executor_base.Executor`. - - This context's `ingest` and `invoke` methods return Python coroutine objects - which represent the actual work of ingestion and invocation in the backing - executor. - - This context will support concurrent invocation of multiple computations if - their arguments have the same cardinalities. - """ - - def __init__( - self, - executor_fn: executor_factory.ExecutorFactory, - compiler_fn: Optional[Callable[[_Computation], object]] = None, - *, - transform_args: Optional[Callable[[object], object]] = None, - transform_result: Optional[Callable[[object], object]] = None, - cardinality_inference_fn: cardinalities_utils.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities, - ): - """Initializes an execution context. - - Args: - executor_fn: Instance of `executor_factory.ExecutorFactory`. - compiler_fn: A Python function that will be used to compile a computation. - transform_args: An `Optional` `Callable` used to transform the args before - they are passed to the computation. - transform_result: An `Optional` `Callable` used to transform the result - before it is returned. - cardinality_inference_fn: A Python function specifying how to infer - cardinalities from arguments (and their associated types). The value - returned by this function will be passed to the `create_executor` method - of `executor_fn` to construct a `tff.framework.Executor` instance. - """ - super().__init__() - py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory) - self._executor_factory = executor_fn - if compiler_fn is not None: - self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn) - else: - self._compiler_pipeline = None - self._transform_args = transform_args - self._transform_result = transform_result - self._cardinality_inference_fn = cardinality_inference_fn - - @contextlib.contextmanager - def _reset_factory_on_error(self, ex_factory, cardinalities): - try: - # We pass a copy down to prevent the caller from mutating. - yield ex_factory.create_executor({**cardinalities}) - except Exception: - ex_factory.clean_up_executor({**cardinalities}) - raise - - @property - def executor_factory(self) -> executor_factory.ExecutorFactory: - return self._executor_factory - - @retrying.retry( - retry_on_exception_filter=_is_retryable_error, - wait_max_ms=30 * 1000, - wait_multiplier=2, - ) - async def invoke(self, comp, arg): - if asyncio.iscoroutine(arg): - # Awaiting if we are passed a coro allows us to install and use the async - # context in conjunction with ConcreteComputations' implementation of - # __call__. - arg = await arg - - if not isinstance(comp.type_signature, computation_types.FunctionType): - raise ValueError( - f'Expected a `tff.FunctionType`, found {comp.type_signature}.' - ) - - if arg is not None and self._transform_args is not None: - # `transform_args` is not intended to handle `tff.structure.Struct`. - # Normalize to a Python structure to make it simpler to handle; `args` is - # sometimes a `tff.structure.Struct` and sometimes it is not, other times - # it is a Python structure that contains a `tff.structure.Struct`. - def _to_python(obj): - if isinstance(obj, structure.Struct): - return structure.to_odict_or_tuple(obj) - else: - return None - - if isinstance(arg, structure.Struct): - args, kwargs = function_utils.unpack_args_from_struct(arg) - args = tree.traverse(_to_python, args) - args = self._transform_args(args) - if not isinstance(args, Sequence): - raise ValueError( - f'Expected `args` to be a `Sequence`, found {type(args)}' - ) - kwargs = tree.traverse(_to_python, kwargs) - kwargs = self._transform_args(kwargs) - if not isinstance(kwargs, Mapping): - raise ValueError( - f'Expected `kwargs` to be a `Mapping`, found {type(kwargs)}' - ) - arg = function_utils.pack_args_into_struct(args, kwargs) - else: - arg = tree.traverse(_to_python, arg) - arg = self._transform_args(arg) - - # Save the type signature before compiling. Compilation currently loses - # container types, so we must remember them here so that they can be - # restored in the output. - result_type = comp.type_signature.result - if self._compiler_pipeline is not None: - with tracing.span('ExecutionContext', 'Compile', span=True): - comp = self._compiler_pipeline.compile(comp) - - with tracing.span('ExecutionContext', 'Invoke', span=True): - if arg is not None: - cardinalities = self._cardinality_inference_fn( - arg, comp.type_signature.parameter - ) - else: - cardinalities = {} - - with self._reset_factory_on_error( - self._executor_factory, cardinalities - ) as executor: - py_typecheck.check_type(executor, executor_base.Executor) - - if arg is not None: - arg = await tracing.wrap_coroutine_in_current_trace_context( - _ingest(executor, arg, comp.type_signature.parameter) - ) - - result = await tracing.wrap_coroutine_in_current_trace_context( - _invoke(executor, comp, arg, result_type) - ) - - if self._transform_result is not None: - result = self._transform_result(result) - return result diff --git a/tensorflow_federated/python/core/impl/execution_contexts/async_execution_context_test.py b/tensorflow_federated/python/core/impl/execution_contexts/async_execution_context_test.py deleted file mode 100644 index 371a0f74c2..0000000000 --- a/tensorflow_federated/python/core/impl/execution_contexts/async_execution_context_test.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.executors import executors_errors - - -class RetryableErrorTest(absltest.TestCase): - - def test_is_retryable_error(self): - retryable_error = executors_errors.RetryableError() - self.assertTrue( - async_execution_context._is_retryable_error(retryable_error) - ) - self.assertFalse(async_execution_context._is_retryable_error(TypeError())) - self.assertFalse(async_execution_context._is_retryable_error(1)) - self.assertFalse(async_execution_context._is_retryable_error('a')) - self.assertFalse(async_execution_context._is_retryable_error(None)) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/execution_contexts/compiler_pipeline.py b/tensorflow_federated/python/core/impl/execution_contexts/compiler_pipeline.py deleted file mode 100644 index 31ed428a84..0000000000 --- a/tensorflow_federated/python/core/impl/execution_contexts/compiler_pipeline.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A pipeline that reduces computations into an executable form.""" - -from collections.abc import Callable -import functools -from typing import Generic, TypeVar - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base - -_Computation = TypeVar('_Computation', bound=computation_base.Computation) - - -class CompilerPipeline(Generic[_Computation]): - """An interface for generating executable artifacts. - - The `CompilerPipeline` holds very little logic; caching for the - artifacts it generates and essentially nothing else. The `CompilerPipeline` - is initialized with a `compiler_fn`, to which the pipeline itself delegates - all the actual work of compilation. - - Different TFF backends may accept different executable artifacts; e.g. a - backend that supports only a map-reduce execution model may accept instances - of `tff.backends.mapreduce.MapReduceForm`. The TFF representation of such a - backend takes the form of an instance of `tff.framework.SyncContext` or - `tff.framework.AsyncContext`, which would be initialized with a - `CompilerPipeline` whose `compilation_fn` accepts a `tff.Computation` and - returns `tff.backends.mapreduce.MapReduceForm`s. - """ - - def __init__(self, compiler_fn: Callable[[_Computation], object]): - self._compiler_fn = compiler_fn - - @functools.lru_cache() - def compile(self, comp: _Computation) -> object: - """Compiles `comp`.""" - py_typecheck.check_type(comp, computation_base.Computation) - return self._compiler_fn(comp) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/compiler_pipeline_test.py b/tensorflow_federated/python/core/impl/execution_contexts/compiler_pipeline_test.py deleted file mode 100644 index 025ee76067..0000000000 --- a/tensorflow_federated/python/core/impl/execution_contexts/compiler_pipeline_test.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.execution_contexts import compiler_pipeline - - -class _FakeComputation(computation_base.Computation): - - def __init__(self, value: int): - self.value = value - - def __call__(self): - raise NotImplementedError() - - def __hash__(self): - return hash(self.value) - - def type_signature(self): - raise NotImplementedError() - - -class CompilerPipelineTest(absltest.TestCase): - - def test_compile_computation_with_identity(self): - comp = _FakeComputation(5) - pipeline = compiler_pipeline.CompilerPipeline(lambda x: x) - - compiled_comp = pipeline.compile(comp) - self.assertEqual(compiled_comp.value, 5) - - # TODO: b/113123410 - Expand the test with more structural invariants. - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py index 537bdb0bb8..3bdeab028b 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context.py @@ -20,25 +20,17 @@ from typing import Generic, Optional, TypeVar, Union import attrs +import federated_language from tensorflow_federated.python.common_libs import async_utils from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.execution_contexts import compiler_pipeline -from tensorflow_federated.python.core.impl.executors import cardinalities_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import typed_object # Type alias for the payload value in a partitioned data structure. Value = TypeVar('Value') -_Computation = TypeVar('_Computation', bound=computation_base.Computation) +_Computation = TypeVar( + '_Computation', bound=federated_language.framework.Computation +) class MergeTypeNotAssignableError(TypeError): @@ -118,15 +110,16 @@ def after_merge(original_arg, merged_result): def __init__( self, *, - up_to_merge: computation_base.Computation, - merge: computation_base.Computation, - after_merge: computation_base.Computation, + up_to_merge: federated_language.framework.Computation, + merge: federated_language.framework.Computation, + after_merge: federated_language.framework.Computation, ): if not ( isinstance( - up_to_merge.type_signature.result, computation_types.FederatedType + up_to_merge.type_signature.result, federated_language.FederatedType ) - and up_to_merge.type_signature.result.placement is placements.SERVER + and up_to_merge.type_signature.result.placement + is federated_language.SERVER ): raise UpToMergeTypeError( 'Expected `up_to_merge` to return a single `tff.SERVER`-placed ' @@ -135,7 +128,7 @@ def __init__( # TFF's StructType assignability relation ensures that an unnamed struct can # be assigned to any struct with names. - expected_merge_param_type = computation_types.StructType([ + expected_merge_param_type = federated_language.StructType([ (None, up_to_merge.type_signature.result.member), # pytype: disable=attribute-error (None, up_to_merge.type_signature.result.member), # pytype: disable=attribute-error ]) @@ -144,10 +137,10 @@ def __init__( ): # pytype: disable=attribute-error raise MergeTypeNotAssignableError( 'Type mismatch checking `merge` type signature.\n' - + computation_types.type_mismatch_error_message( + + federated_language.framework.type_mismatch_error_message( merge.type_signature.parameter, expected_merge_param_type, - computation_types.TypeRelation.ASSIGNABLE, + federated_language.framework.TypeRelation.ASSIGNABLE, second_is_expected=True, ) ) @@ -169,18 +162,18 @@ def __init__( if up_to_merge.type_signature.parameter is not None: # TODO: b/147499373 - If None arguments were uniformly represented as # empty tuples, we could avoid this and related ugly if/else casing. - expected_after_merge_arg_type = computation_types.StructType([ + expected_after_merge_arg_type = federated_language.StructType([ (None, up_to_merge.type_signature.parameter), ( None, - computation_types.FederatedType( - merge.type_signature.result, placements.SERVER + federated_language.FederatedType( + merge.type_signature.result, federated_language.SERVER ), ), ]) else: - expected_after_merge_arg_type = computation_types.FederatedType( - merge.type_signature.result, placements.SERVER + expected_after_merge_arg_type = federated_language.FederatedType( + merge.type_signature.result, federated_language.SERVER ) after_merge.type_signature.parameter.check_assignable_from( @@ -188,24 +181,24 @@ def __init__( ) # pytype: disable=attribute-error def _federated_type_predicate( - type_signature: computation_types.Type, - placement: placements.PlacementLiteral, + type_signature: federated_language.Type, + placement: federated_language.framework.PlacementLiteral, ) -> bool: return ( - isinstance(type_signature, computation_types.FederatedType) + isinstance(type_signature, federated_language.FederatedType) and type_signature.placement is placement ) def _moves_clients_to_server_predicate( - intrinsic: building_blocks.Intrinsic, + intrinsic: federated_language.framework.Intrinsic, ): - parameter_contains_clients_placement = type_analysis.contains( + parameter_contains_clients_placement = federated_language.framework.type_contains( intrinsic.type_signature.parameter, # pytype: disable=attribute-error - lambda x: _federated_type_predicate(x, placements.CLIENTS), + lambda x: _federated_type_predicate(x, federated_language.CLIENTS), ) - result_contains_server_placement = type_analysis.contains( + result_contains_server_placement = federated_language.framework.type_contains( intrinsic.type_signature.result, # pytype: disable=attribute-error - lambda x: _federated_type_predicate(x, placements.SERVER), + lambda x: _federated_type_predicate(x, federated_language.SERVER), ) return ( parameter_contains_clients_placement @@ -215,11 +208,11 @@ def _moves_clients_to_server_predicate( aggregations = set() def _aggregation_predicate( - comp: building_blocks.ComputationBuildingBlock, + comp: federated_language.framework.ComputationBuildingBlock, ) -> bool: - if not isinstance(comp, building_blocks.Intrinsic): + if not isinstance(comp, federated_language.framework.Intrinsic): return False - if not isinstance(comp.type_signature, computation_types.FunctionType): + if not isinstance(comp.type_signature, federated_language.FunctionType): return False if _moves_clients_to_server_predicate(comp): aggregations.add((comp.uri, comp.type_signature)) @@ -230,7 +223,9 @@ def _aggregation_predicate( # computation.protos; to avoid opening up a visibility hole that isn't # technically necessary here, we prefer to simply skip the static check here # for computations which cannot convert themselves to building blocks. - if hasattr(after_merge, 'to_building_block') and tree_analysis.contains( + if hasattr( + after_merge, 'to_building_block' + ) and federated_language.framework.computation_contains( after_merge.to_building_block(), _aggregation_predicate ): formatted_aggregations = ', '.join( @@ -259,10 +254,10 @@ class _PartitioningValue: def _partition_value( - val: _PartitioningValue, type_signature: computation_types.Type + val: _PartitioningValue, type_signature: federated_language.Type ) -> _PartitioningValue: """Partitions value as specified in _split_value_into_subrounds.""" - if isinstance(type_signature, computation_types.StructType): + if isinstance(type_signature, federated_language.StructType): struct_val = structure.from_container(val.payload) partition_result: Optional[_PartitioningValue] = None result_container = [] @@ -287,8 +282,8 @@ def _partition_value( partition_result.last_client_index, ) elif ( - isinstance(type_signature, computation_types.FederatedType) - and type_signature.placement is placements.CLIENTS + isinstance(type_signature, federated_language.FederatedType) + and type_signature.placement is federated_language.CLIENTS ): if type_signature.all_equal: # In this case we simply replicate the argument for every subround. @@ -315,7 +310,7 @@ def _partition_value( def _split_value_into_subrounds( - value: Value, type_spec: computation_types.Type, num_desired_subrounds: int + value: Value, type_spec: federated_language.Type, num_desired_subrounds: int ) -> list[Value]: """Partitions clients-placed values to subrounds, replicating other values. @@ -334,26 +329,28 @@ def _split_value_into_subrounds( value: The argument to a computation intended to be invoked in subrounds, which will be partitioned. `value` can be any structure understood by TFF's native execution contexts. - type_spec: The `computation_types.Type` corresponding to `value`. + type_spec: The `federated_language.Type` corresponding to `value`. num_desired_subrounds: Int specifying the desired number of subrounds to run. Specifies the maximum length of the returned list. Returns: A list of partitioned values as described above. """ - cardinalities = cardinalities_utils.infer_cardinalities(value, type_spec) - if cardinalities.get(placements.CLIENTS) is None: + cardinalities = federated_language.framework.infer_cardinalities( + value, type_spec + ) + if cardinalities.get(federated_language.CLIENTS) is None: # The argument contains no clients-placed values, but may still perform # nontrivial clients-placed work. return [value for _ in range(num_desired_subrounds)] - elif cardinalities[placements.CLIENTS] == 0: + elif cardinalities[federated_language.CLIENTS] == 0: # Here the argument contains an empty clients-placed value; therefore this # computation should be run over an empty set of clients. return [value] partitioning_value = _PartitioningValue( payload=value, - num_remaining_clients=cardinalities[placements.CLIENTS], + num_remaining_clients=cardinalities[federated_language.CLIENTS], num_remaining_partitions=num_desired_subrounds, last_client_index=0, ) @@ -380,11 +377,11 @@ def _split_value_into_subrounds( def _repackage_partitioned_values( after_merge_results: Union[list[Value], tuple[Value, ...]], - result_type_spec: computation_types.Type, + result_type_spec: federated_language.Type, ) -> Value: """Inverts `_split_value_into_subrounds` above.""" py_typecheck.check_type(after_merge_results, (tuple, list)) - if isinstance(result_type_spec, computation_types.StructType): + if isinstance(result_type_spec, federated_language.StructType): after_merge_structs = [ structure.from_container(x) for x in after_merge_results ] @@ -400,8 +397,8 @@ def _repackage_partitioned_values( )) return structure.Struct(result_container) elif ( - isinstance(result_type_spec, computation_types.FederatedType) - and result_type_spec.placement is placements.CLIENTS + isinstance(result_type_spec, federated_language.FederatedType) + and result_type_spec.placement is federated_language.CLIENTS ): if result_type_spec.all_equal: return after_merge_results[0] @@ -413,16 +410,16 @@ def _repackage_partitioned_values( return after_merge_results[0] -class MergeableCompExecutionContextValue(typed_object.TypedObject): +class MergeableCompExecutionContextValue(federated_language.TypedObject): """Represents a value embedded in the `MergeableCompExecutionContext`.""" def __init__( self, value: object, - type_spec: computation_types.Type, + type_spec: federated_language.Type, num_desired_subrounds: int, ): - py_typecheck.check_type(type_spec, computation_types.Type) + py_typecheck.check_type(type_spec, federated_language.Type) self._type_signature = type_spec self._partitioned_value = _split_value_into_subrounds( value, self._type_signature, num_desired_subrounds=num_desired_subrounds @@ -437,7 +434,9 @@ def value(self): async def _invoke_up_to_merge_and_return_context( - comp: MergeableCompForm, arg, context: context_base.AsyncContext + comp: MergeableCompForm, + arg, + context: federated_language.framework.AsyncContext, ): return await context.invoke( comp.up_to_merge, # pytype: disable=attribute-error @@ -449,7 +448,7 @@ async def _merge_results( comp: MergeableCompForm, merge_partial, value_to_merge, - context: context_base.AsyncContext, + context: federated_language.framework.AsyncContext, ): return await context.invoke( comp.merge, # pytype: disable=attribute-error @@ -461,7 +460,7 @@ async def _compute_after_merged( comp: MergeableCompForm, original_arg, merge_result, - context: context_base.AsyncContext, + context: federated_language.framework.AsyncContext, ): if original_arg is not None: arg = structure.Struct.unnamed(original_arg, merge_result) @@ -474,14 +473,17 @@ async def _compute_after_merged( async def _run_in_async_context_pool( - task_fn: Callable[[object, context_base.AsyncContext], asyncio.Task], + task_fn: Callable[ + [object, federated_language.framework.AsyncContext], asyncio.Task + ], arg_list: Sequence[object], - execution_contexts: Sequence[context_base.AsyncContext], + execution_contexts: Sequence[federated_language.framework.AsyncContext], initial_result: object, postprocessing_hook: Callable[ - [object, object, context_base.AsyncContext], Awaitable[Value] + [object, object, federated_language.framework.AsyncContext], + Awaitable[Value], ], -) -> tuple[Value, Optional[context_base.AsyncContext]]: +) -> tuple[Value, Optional[federated_language.framework.AsyncContext]]: """Runs the tasks against the execution pool, sequentializing the extra work. Args: @@ -534,7 +536,7 @@ async def _run_in_async_context_pool( async def _invoke_merge_in_async_pool( comp: MergeableCompForm, arg_list: Sequence[object], - execution_contexts: Sequence[context_base.AsyncContext], + execution_contexts: Sequence[federated_language.framework.AsyncContext], ): """Invokes up to merge and merge in a pool of async contexts.""" @@ -561,7 +563,7 @@ async def _invoke_after_merge_in_async_pool( comp: MergeableCompForm, merge_result: object, arg_list: Sequence[object], - execution_contexts: Sequence[context_base.AsyncContext], + execution_contexts: Sequence[federated_language.framework.AsyncContext], ) -> list[object]: """Invokes after_merge in a pool of async contexts, returning result.""" @@ -586,7 +588,7 @@ async def postprocessing(result, partial_result, context): async def _invoke_mergeable_comp_form( comp: MergeableCompForm, arg: Optional[MergeableCompExecutionContextValue], - execution_contexts: Sequence[context_base.AsyncContext], + execution_contexts: Sequence[federated_language.framework.AsyncContext], ): """Invokes `comp` on `arg`, repackaging the results to a single value.""" @@ -599,13 +601,13 @@ async def _invoke_mergeable_comp_form( comp, arg_list, execution_contexts ) - def _predicate(type_spec: computation_types.Type) -> bool: + def _predicate(type_spec: federated_language.Type) -> bool: return ( - not isinstance(type_spec, computation_types.FederatedType) + not isinstance(type_spec, federated_language.FederatedType) or type_spec.all_equal ) - if type_analysis.contains_only( + if federated_language.framework.type_contains_only( comp.after_merge.type_signature.result, # pytype: disable=attribute-error _predicate, ): @@ -630,7 +632,7 @@ def _predicate(type_spec: computation_types.Type) -> bool: class MergeableCompExecutionContext( - context_base.SyncContext, Generic[_Computation] + federated_language.framework.SyncContext, Generic[_Computation] ): """Context which executes mergeable computations in subrounds. @@ -643,7 +645,7 @@ class MergeableCompExecutionContext( def __init__( self, - async_contexts: Sequence[context_base.AsyncContext], + async_contexts: Sequence[federated_language.framework.AsyncContext], compiler_fn: Optional[Callable[[_Computation], MergeableCompForm]] = None, transform_args: Optional[Callable[[object], object]] = None, transform_result: Optional[Callable[[object], object]] = None, @@ -670,7 +672,7 @@ def __init__( """ self._async_runner = async_utils.AsyncThreadRunner() for ctx in async_contexts: - py_typecheck.check_type(ctx, context_base.AsyncContext) + py_typecheck.check_type(ctx, federated_language.framework.AsyncContext) self._async_execution_contexts = async_contexts self._transform_args = transform_args self._transform_result = transform_result @@ -680,20 +682,22 @@ def __init__( else len(self._async_execution_contexts) ) if compiler_fn is not None: - self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn) + self._compiler_pipeline = federated_language.framework.CompilerPipeline( + compiler_fn + ) else: self._compiler_pipeline = None def invoke( self, - comp: Union[MergeableCompForm, computation_base.Computation], + comp: Union[MergeableCompForm, federated_language.framework.Computation], arg: Optional[object] = None, ): py_typecheck.check_type( - comp, (MergeableCompForm, computation_base.Computation) + comp, (MergeableCompForm, federated_language.framework.Computation) ) - if isinstance(comp, computation_base.Computation): + if isinstance(comp, federated_language.framework.Computation): if self._compiler_pipeline is None: raise ValueError( 'Without a compiler, mergeable comp execution context ' @@ -715,7 +719,7 @@ def invoke( self._num_subrounds, ) - result = type_conversions.type_to_py_container( + result = federated_language.framework.type_to_py_container( self._async_runner.run_coro_and_return_result( _invoke_mergeable_comp_form( comp, arg, self._async_execution_contexts diff --git a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context_test.py b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context_test.py index e7c6597ea7..39f34f0205 100644 --- a/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context_test.py +++ b/tensorflow_federated/python/core/impl/execution_contexts/mergeable_comp_execution_context_test.py @@ -15,20 +15,19 @@ import collections from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.impl.execution_contexts import mergeable_comp_execution_context -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements class PartitionValueTest(absltest.TestCase): def test_partitions_value_with_no_clients_arguments(self): value = 0 - type_signature = computation_types.FederatedType( - np.int32, placements.SERVER + type_signature = federated_language.FederatedType( + np.int32, federated_language.SERVER ) num_desired_subrounds = 2 partitioned_value = ( @@ -40,9 +39,19 @@ def test_partitions_value_with_no_clients_arguments(self): def test_wraps_value_with_empty_client_argument(self): value = (0, []) - type_signature = computation_types.StructType([ - (None, computation_types.FederatedType(np.int32, placements.SERVER)), - (None, computation_types.FederatedType(np.int32, placements.CLIENTS)), + type_signature = federated_language.StructType([ + ( + None, + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + ), + ( + None, + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), + ), ]) num_desired_subrounds = 2 partitioned_value = ( @@ -54,12 +63,17 @@ def test_wraps_value_with_empty_client_argument(self): def test_replicates_all_equal_clients_argument(self): value = (0, 1) - type_signature = computation_types.StructType([ - (None, computation_types.FederatedType(np.int32, placements.SERVER)), + type_signature = federated_language.StructType([ ( None, - computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + ), + ( + None, + federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ), ), ]) @@ -73,8 +87,8 @@ def test_replicates_all_equal_clients_argument(self): def test_partitions_client_placed_value_into_subrounds(self): value = list(range(10)) - type_signature = computation_types.FederatedType( - np.int32, placements.CLIENTS + type_signature = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) num_desired_subrounds = 5 partitioned_value = ( @@ -89,14 +103,18 @@ def test_partitions_clients_placed_struct_elem_into_subrounds(self): value = (0, list(range(10))) server_placed_name = 'a' clients_placed_name = 'b' - type_signature = computation_types.StructType([ + type_signature = federated_language.StructType([ ( server_placed_name, - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ), ( clients_placed_name, - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ), ]) @@ -116,8 +134,8 @@ def test_partitions_clients_placed_struct_elem_into_subrounds(self): def test_partitions_fewer_clients_than_rounds_into_nonempty_rounds(self): value = [0, 1] - type_signature = computation_types.FederatedType( - np.int32, placements.CLIENTS + type_signature = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) num_desired_subrounds = 5 partitioned_value = ( @@ -149,25 +167,36 @@ def assertRoundTripEqual( def test_roundtrip_with_no_clients_argument(self): value = 0 - type_signature = computation_types.FederatedType( - np.int32, placements.SERVER + type_signature = federated_language.FederatedType( + np.int32, federated_language.SERVER ) self.assertRoundTripEqual(value, type_signature, value) def test_roundtrip_with_named_struct(self): value = collections.OrderedDict(a=0) - type_signature = computation_types.StructType( - [('a', computation_types.FederatedType(np.int32, placements.SERVER))] - ) + type_signature = federated_language.StructType([( + 'a', + federated_language.FederatedType(np.int32, federated_language.SERVER), + )]) self.assertRoundTripEqual( value, type_signature, structure.Struct([('a', 0)]) ) def test_roundtrip_with_empty_clients_argument(self): value = (0, []) - type_signature = computation_types.StructType([ - (None, computation_types.FederatedType(np.int32, placements.SERVER)), - (None, computation_types.FederatedType(np.int32, placements.CLIENTS)), + type_signature = federated_language.StructType([ + ( + None, + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + ), + ( + None, + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), + ), ]) self.assertRoundTripEqual( value, type_signature, structure.from_container(value) @@ -175,26 +204,31 @@ def test_roundtrip_with_empty_clients_argument(self): def test_roundtrip_with_nonempty_clients_argument(self): value = list(range(10)) - type_signature = computation_types.FederatedType( - np.int32, placements.CLIENTS + type_signature = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) self.assertRoundTripEqual(value, type_signature, value) def test_roundtrip_with_nonempty_tuple_clients_argument(self): value = tuple(range(10)) - type_signature = computation_types.FederatedType( - np.int32, placements.CLIENTS + type_signature = federated_language.FederatedType( + np.int32, federated_language.CLIENTS ) self.assertRoundTripEqual(value, type_signature, value) def test_roundtrip_with_all_equal_clients_argument(self): value = (0, 1) - type_signature = computation_types.StructType([ - (None, computation_types.FederatedType(np.int32, placements.SERVER)), + type_signature = federated_language.StructType([ + ( + None, + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + ), ( None, - computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ), ), ]) diff --git a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py b/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py deleted file mode 100644 index 87e20148e5..0000000000 --- a/tensorflow_federated/python/core/impl/execution_contexts/sync_execution_context.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A context for execution based on an embedded executor instance.""" - -from collections.abc import Callable -from typing import Generic, Optional, TypeVar - -from tensorflow_federated.python.common_libs import async_utils -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.executors import cardinalities_utils -from tensorflow_federated.python.core.impl.executors import executor_factory - - -_Computation = TypeVar('_Computation', bound=computation_base.Computation) - - -class SyncExecutionContext(context_base.SyncContext, Generic[_Computation]): - """A synchronous execution context backed by an `executor_base.Executor`.""" - - def __init__( - self, - executor_fn: executor_factory.ExecutorFactory, - compiler_fn: Optional[Callable[[_Computation], object]] = None, - *, - transform_args: Optional[Callable[[object], object]] = None, - transform_result: Optional[Callable[[object], object]] = None, - cardinality_inference_fn: cardinalities_utils.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities, - ): - """Initializes a synchronous execution context which retries invocations. - - Args: - executor_fn: Instance of `executor_factory.ExecutorFactory`. - compiler_fn: A Python function that will be used to compile a computation. - transform_args: An `Optional` `Callable` used to transform the args before - they are passed to the computation. - transform_result: An `Optional` `Callable` used to transform the result - before it is returned. - cardinality_inference_fn: A Python function specifying how to infer - cardinalities from arguments (and their associated types). The value - returned by this function will be passed to the `create_executor` method - of `executor_fn` to construct a `tff.framework.Executor` instance. - """ - py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory) - self._executor_factory = executor_fn - self._async_context = async_execution_context.AsyncExecutionContext( - executor_fn=executor_fn, - compiler_fn=compiler_fn, - transform_args=transform_args, - transform_result=transform_result, - cardinality_inference_fn=cardinality_inference_fn, - ) - self._async_runner = async_utils.AsyncThreadRunner() - - @property - def executor_factory(self): - return self._executor_factory - - def invoke(self, comp, arg): - return self._async_runner.run_coro_and_return_result( - self._async_context.invoke(comp, arg) - ) diff --git a/tensorflow_federated/python/core/impl/executor_stacks/BUILD b/tensorflow_federated/python/core/impl/executor_stacks/BUILD index a97fcfc5e8..c189b2134c 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/BUILD +++ b/tensorflow_federated/python/core/impl/executor_stacks/BUILD @@ -34,11 +34,9 @@ py_library( ":executor_stack_bindings", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/impl/executors:cpp_to_python_executor", - "//tensorflow_federated/python/core/impl/executors:executor_base", "//tensorflow_federated/python/core/impl/executors:executor_bindings", - "//tensorflow_federated/python/core/impl/executors:executor_factory", "//tensorflow_federated/python/core/impl/executors:executors_errors", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -50,11 +48,9 @@ py_test( ], deps = [ ":cpp_executor_factory", - "//tensorflow_federated/python/core/impl/executors:executor_base", "//tensorflow_federated/python/core/impl/executors:executor_bindings", - "//tensorflow_federated/python/core/impl/executors:executor_factory", "//tensorflow_federated/python/core/impl/executors:executor_test_utils_bindings", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -65,10 +61,9 @@ py_library( tags = ["ignore_for_dep=third_party.py.IPython.get_ipython"], deps = [ ":python_executor_stacks", - "//tensorflow_federated/python/core/impl/executors:executor_factory", "//tensorflow_federated/python/core/impl/executors:remote_executor", "//tensorflow_federated/python/core/impl/executors:remote_executor_grpc_stub", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -83,7 +78,7 @@ py_library( "//tensorflow_federated/cc/core/impl/executor_stacks:executor_stack_bindings", "//tensorflow_federated/python/core/impl/executors:data_conversions", "//tensorflow_federated/python/core/impl/executors:executor_bindings", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -96,7 +91,7 @@ py_test( deps = [ ":executor_stack_bindings", "//tensorflow_federated/python/core/impl/executors:executor_bindings", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -106,8 +101,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/executors:executor_base", - "//tensorflow_federated/python/core/impl/executors:executor_factory", + "@federated_language//federated_language", ], ) @@ -118,8 +112,6 @@ py_test( shard_count = 5, deps = [ ":python_executor_stacks", - "//tensorflow_federated/python/core/impl/executors:executor_base", - "//tensorflow_federated/python/core/impl/executors:executor_factory", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory.py b/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory.py index 6c1dd687ea..ea31cc708b 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory.py @@ -20,15 +20,13 @@ from absl import logging import cachetools +import federated_language from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.impl.executor_stacks import executor_stack_bindings from tensorflow_federated.python.core.impl.executors import cpp_to_python_executor -from tensorflow_federated.python.core.impl.executors import executor_base from tensorflow_federated.python.core.impl.executors import executor_bindings -from tensorflow_federated.python.core.impl.executors import executor_factory from tensorflow_federated.python.core.impl.executors import executors_errors -from tensorflow_federated.python.core.impl.types import placements # Users likely do not intend to run 4 or more TensorFlow functions sequentially; # we special-case to warn users explicitly in this case, in addition to @@ -36,17 +34,20 @@ _CONCURRENCY_LEVEL_TO_WARN = 4 -def _get_hashable_key(cardinalities: executor_factory.CardinalitiesType): +def _get_hashable_key( + cardinalities: federated_language.framework.CardinalitiesType, +): return tuple(sorted((str(k), v) for k, v in cardinalities.items())) -class CPPExecutorFactory(executor_factory.ExecutorFactory): +class CPPExecutorFactory(federated_language.framework.ExecutorFactory): """An ExecutorFactory which wraps a simple executor_fn.""" def __init__( self, executor_fn: Callable[ - [executor_factory.CardinalitiesType], executor_bindings.Executor + [federated_language.framework.CardinalitiesType], + executor_bindings.Executor, ], executor_cache_size: int = 5, ): @@ -55,8 +56,8 @@ def __init__( self._executors = cachetools.LRUCache(self._cache_size) def create_executor( - self, cardinalities: executor_factory.CardinalitiesType - ) -> executor_base.Executor: + self, cardinalities: federated_language.framework.CardinalitiesType + ) -> federated_language.framework.Executor: cardinalities_key = _get_hashable_key(cardinalities) if cardinalities_key not in self._executors: cpp_executor = self._executor_fn(cardinalities) @@ -68,7 +69,7 @@ def create_executor( return self._executors[cardinalities_key] def clean_up_executor( - self, cardinalities: executor_factory.CardinalitiesType + self, cardinalities: federated_language.framework.CardinalitiesType ): cardinalities_key = _get_hashable_key(cardinalities) ex = self._executors.get(cardinalities_key) @@ -121,16 +122,16 @@ def local_cpp_executor_factory( client_leaf_executor_fn: Optional[ Callable[[int], executor_bindings.Executor] ] = None, -) -> executor_factory.ExecutorFactory: +) -> federated_language.framework.ExecutorFactory: """Local ExecutorFactory backed by C++ Executor bindings.""" _check_num_clients_is_valid(default_num_clients) def _executor_fn( - cardinalities: executor_factory.CardinalitiesType, + cardinalities: federated_language.framework.CardinalitiesType, ) -> executor_bindings.Executor: - if cardinalities.get(placements.CLIENTS) is None: - cardinalities[placements.CLIENTS] = default_num_clients - num_clients = cardinalities[placements.CLIENTS] + if cardinalities.get(federated_language.CLIENTS) is None: + cardinalities[federated_language.CLIENTS] = default_num_clients + num_clients = cardinalities[federated_language.CLIENTS] if ( max_concurrent_computation_calls > 0 and num_clients > max_concurrent_computation_calls @@ -189,15 +190,15 @@ def remote_cpp_executor_factory( default_num_clients: int = 0, stream_structs: bool = False, max_concurrent_computation_calls: int = -1, -) -> executor_factory.ExecutorFactory: +) -> federated_language.framework.ExecutorFactory: """ExecutorFactory backed by C++ Executor bindings.""" _check_num_clients_is_valid(default_num_clients) def _executor_fn( - cardinalities: executor_factory.CardinalitiesType, + cardinalities: federated_language.framework.CardinalitiesType, ) -> executor_bindings.Executor: - if cardinalities.get(placements.CLIENTS) is None: - cardinalities[placements.CLIENTS] = default_num_clients + if cardinalities.get(federated_language.CLIENTS) is None: + cardinalities[federated_language.CLIENTS] = default_num_clients try: if stream_structs: return executor_stack_bindings.create_streaming_remote_executor_stack( diff --git a/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory_test.py b/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory_test.py index f76f136f72..836cb3b6cb 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory_test.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/cpp_executor_factory_test.py @@ -13,13 +13,11 @@ # limitations under the License. from absl.testing import absltest +import federated_language from tensorflow_federated.python.core.impl.executor_stacks import cpp_executor_factory -from tensorflow_federated.python.core.impl.executors import executor_base from tensorflow_federated.python.core.impl.executors import executor_bindings -from tensorflow_federated.python.core.impl.executors import executor_factory from tensorflow_federated.python.core.impl.executors import executor_test_utils_bindings -from tensorflow_federated.python.core.impl.types import placements def _create_mock_execution_stack( @@ -43,16 +41,18 @@ def test_create_local_cpp_factory_constructs(self): local_cpp_factory = cpp_executor_factory.local_cpp_executor_factory( default_num_clients=0, leaf_executor_fn=_create_mock_execution_stack ) - self.assertIsInstance(local_cpp_factory, executor_factory.ExecutorFactory) + self.assertIsInstance( + local_cpp_factory, federated_language.framework.ExecutorFactory + ) def test_clean_up_executors_clears_state(self): local_cpp_factory = cpp_executor_factory.local_cpp_executor_factory( default_num_clients=0, leaf_executor_fn=_create_mock_execution_stack ) - cardinalities = {placements.CLIENTS: 1} + cardinalities = {federated_language.CLIENTS: 1} local_cpp_factory.create_executor(cardinalities) for executor in local_cpp_factory._executors.values(): - self.assertIsInstance(executor, executor_base.Executor) + self.assertIsInstance(executor, federated_language.framework.Executor) local_cpp_factory.clean_up_executor(cardinalities) self.assertEmpty(local_cpp_factory._executors) @@ -60,9 +60,13 @@ def test_create_local_cpp_factory_constructs_executor_implementation(self): local_cpp_factory = cpp_executor_factory.local_cpp_executor_factory( default_num_clients=0, leaf_executor_fn=_create_mock_execution_stack ) - self.assertIsInstance(local_cpp_factory, executor_factory.ExecutorFactory) - executor = local_cpp_factory.create_executor({placements.CLIENTS: 1}) - self.assertIsInstance(executor, executor_base.Executor) + self.assertIsInstance( + local_cpp_factory, federated_language.framework.ExecutorFactory + ) + executor = local_cpp_factory.create_executor( + {federated_language.CLIENTS: 1} + ) + self.assertIsInstance(executor, federated_language.framework.Executor) def test_create_remote_cpp_factory_constructs(self): targets = ['localhost:8000', 'localhost:8001'] @@ -72,7 +76,9 @@ def test_create_remote_cpp_factory_constructs(self): remote_cpp_factory = cpp_executor_factory.remote_cpp_executor_factory( channels=channels, default_num_clients=0 ) - self.assertIsInstance(remote_cpp_factory, executor_factory.ExecutorFactory) + self.assertIsInstance( + remote_cpp_factory, federated_language.framework.ExecutorFactory + ) def test_create_remote_cpp_factory_raises_with_no_available_workers(self): targets = ['localhost:8000', 'localhost:8001'] @@ -82,9 +88,11 @@ def test_create_remote_cpp_factory_raises_with_no_available_workers(self): remote_cpp_factory = cpp_executor_factory.remote_cpp_executor_factory( channels=channels, default_num_clients=0 ) - self.assertIsInstance(remote_cpp_factory, executor_factory.ExecutorFactory) + self.assertIsInstance( + remote_cpp_factory, federated_language.framework.ExecutorFactory + ) with self.assertRaises(Exception): - remote_cpp_factory.create_executor({placements.CLIENTS: 1}) + remote_cpp_factory.create_executor({federated_language.CLIENTS: 1}) def test_create_cpp_factory_raises_with_invalid_default_num_clients(self): with self.subTest('local_nonnegative'): diff --git a/tensorflow_federated/python/core/impl/executor_stacks/executor_factory.py b/tensorflow_federated/python/core/impl/executor_stacks/executor_factory.py index 19a9977c74..3589ded023 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/executor_factory.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/executor_factory.py @@ -21,14 +21,13 @@ import time from absl import logging +import federated_language import grpc import portpicker from tensorflow_federated.python.core.impl.executor_stacks import python_executor_stacks -from tensorflow_federated.python.core.impl.executors import executor_factory from tensorflow_federated.python.core.impl.executors import remote_executor from tensorflow_federated.python.core.impl.executors import remote_executor_grpc_stub -from tensorflow_federated.python.core.impl.types import placements _LOCALHOST_SERVER_WAIT_TIME_SEC = 1.0 @@ -45,7 +44,7 @@ def local_cpp_executor_factory( default_num_clients: int = 0, max_concurrent_computation_calls: int = -1, stream_structs: bool = False, -) -> executor_factory.ExecutorFactory: +) -> federated_language.framework.ExecutorFactory: """Returns an execution context backed by C++ runtime. Args: @@ -150,8 +149,8 @@ def get_stub(self) -> remote_executor_grpc_stub.RemoteExecutorGrpcStub: service_manager = ServiceManager() def stack_fn(cardinalities): - if cardinalities.get(placements.CLIENTS) is None: - cardinalities[placements.CLIENTS] = default_num_clients + if cardinalities.get(federated_language.CLIENTS) is None: + cardinalities[federated_language.CLIENTS] = default_num_clients stub = service_manager.get_stub() ex = remote_executor.RemoteExecutor(stub, stream_structs=stream_structs) ex.set_cardinalities(cardinalities) diff --git a/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py b/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py index 624199e43e..510493b16b 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings.py @@ -15,10 +15,11 @@ from collections.abc import Mapping, Sequence +import federated_language + from tensorflow_federated.cc.core.impl.executor_stacks import executor_stack_bindings from tensorflow_federated.python.core.impl.executors import data_conversions from tensorflow_federated.python.core.impl.executors import executor_bindings -from tensorflow_federated.python.core.impl.types import placements def filter_to_live_channels( @@ -33,7 +34,7 @@ def filter_to_live_channels( def create_remote_executor_stack( channels: Sequence[executor_bindings.GRPCChannel], - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], max_concurrent_computation_calls: int = -1, ) -> executor_bindings.Executor: """Constructs a RemoteExecutor proxying services on `targets`.""" @@ -47,7 +48,7 @@ def create_remote_executor_stack( def create_streaming_remote_executor_stack( channels: Sequence[executor_bindings.GRPCChannel], - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> executor_bindings.Executor: """Constructs a RemoteExecutor proxying services on `targets`.""" uri_cardinalities = ( diff --git a/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings_test.py b/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings_test.py index d45b9cc4d3..e6bd08b69a 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings_test.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/executor_stack_bindings_test.py @@ -14,14 +14,14 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.core.impl.executor_stacks import executor_stack_bindings from tensorflow_federated.python.core.impl.executors import executor_bindings -from tensorflow_federated.python.core.impl.types import placements _TARGET_LIST = ['localhost:8000', 'localhost:8001'] -_CARDINALITIES = {placements.CLIENTS: 5} +_CARDINALITIES = {federated_language.CLIENTS: 5} class ExecutorStackBindingsTest(parameterized.TestCase): diff --git a/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks.py b/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks.py index 01603d45c2..4c05872e16 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks.py @@ -16,10 +16,9 @@ from collections.abc import Callable import cachetools +import federated_language from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.executors import executor_base -from tensorflow_federated.python.core.impl.executors import executor_factory # Place a limit on the maximum size of the executor caches managed by the # ExecutorFactories, to prevent unbounded thread and memory growth in the case @@ -27,37 +26,43 @@ _EXECUTOR_CACHE_SIZE = 10 -def _get_hashable_key(cardinalities: executor_factory.CardinalitiesType): +def _get_hashable_key( + cardinalities: federated_language.framework.CardinalitiesType, +): return tuple(sorted((str(k), v) for k, v in cardinalities.items())) -class ResourceManagingExecutorFactory(executor_factory.ExecutorFactory): +class ResourceManagingExecutorFactory( + federated_language.framework.ExecutorFactory +): """Implementation of executor factory holding an executor per cardinality.""" def __init__( self, executor_stack_fn: Callable[ - [executor_factory.CardinalitiesType], executor_base.Executor + [federated_language.framework.CardinalitiesType], + federated_language.framework.Executor, ], ): """Initializes `ResourceManagingExecutorFactory`. `ResourceManagingExecutorFactory` manages a mapping from `cardinalities` - to `executor_base.Executors`, closing and destroying the executors in this + to `federated_language.framework.Executors`, closing and destroying the + executors in this mapping when asked. Args: executor_stack_fn: Callable taking a mapping from - `placements.PlacementLiteral` to integers, and returning an - `executor_base.Executor`. The returned executor will be configured to - handle these cardinalities. + `federated_language.framework.PlacementLiteral` to integers, and + returning an `federated_language.framework.Executor`. The returned + executor will be configured to handle these cardinalities. """ self._executor_stack_fn = executor_stack_fn self._executors = cachetools.LRUCache(_EXECUTOR_CACHE_SIZE) def create_executor( - self, cardinalities: executor_factory.CardinalitiesType - ) -> executor_base.Executor: + self, cardinalities: federated_language.framework.CardinalitiesType + ) -> federated_language.framework.Executor: """Constructs or gets existing executor. Returns a previously-constructed executor if this method has already been @@ -65,13 +70,14 @@ def create_executor( with `cardinalities` and returns the result. Args: - cardinalities: `dict` with `placements.PlacementLiteral` keys and integer - values, specifying the population size at each placement. The executor - stacks returned from this method are not themselves polymorphic; a - concrete stack must have fixed sizes at each placement. + cardinalities: `dict` with `federated_language.framework.PlacementLiteral` + keys and integer values, specifying the population size at each + placement. The executor stacks returned from this method are not + themselves polymorphic; a concrete stack must have fixed sizes at each + placement. Returns: - Instance of `executor_base.Executor` as described above. + Instance of `federated_language.framework.Executor` as described above. """ py_typecheck.check_type(cardinalities, dict) key = _get_hashable_key(cardinalities) @@ -79,12 +85,12 @@ def create_executor( if ex is not None: return ex ex = self._executor_stack_fn(cardinalities) - py_typecheck.check_type(ex, executor_base.Executor) + py_typecheck.check_type(ex, federated_language.framework.Executor) self._executors[key] = ex return ex def clean_up_executor( - self, cardinalities: executor_factory.CardinalitiesType + self, cardinalities: federated_language.framework.CardinalitiesType ): """Calls `close` on constructed executors, resetting internal cache. diff --git a/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks_test.py b/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks_test.py index 8d84188313..6743ef6a9a 100644 --- a/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks_test.py +++ b/tensorflow_federated/python/core/impl/executor_stacks/python_executor_stacks_test.py @@ -16,14 +16,12 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language from tensorflow_federated.python.core.impl.executor_stacks import python_executor_stacks -from tensorflow_federated.python.core.impl.executors import executor_base -from tensorflow_federated.python.core.impl.executors import executor_factory -from tensorflow_federated.python.core.impl.types import placements -class ExecutorMock(mock.MagicMock, executor_base.Executor): +class ExecutorMock(mock.MagicMock, federated_language.framework.Executor): async def create_value(self, *args): pass @@ -44,7 +42,8 @@ async def close(self, *args): class ConcreteExecutorFactoryTest(parameterized.TestCase): def test_subclass_base_fails_no_create_method(self): - class NotCallable(executor_factory.ExecutorFactory): + + class NotCallable(federated_language.framework.ExecutorFactory): def clean_up_executor(self, x): pass @@ -53,7 +52,8 @@ def clean_up_executor(self, x): NotCallable() def test_subclass_base_fails_no_cleanup(self): - class NoCleanup(executor_factory.ExecutorFactory): + + class NoCleanup(federated_language.framework.ExecutorFactory): def create_executor(self, x): pass @@ -62,7 +62,8 @@ def create_executor(self, x): NoCleanup() def test_instantiation_succeeds_both_methods_specified(self): - class Fine(executor_factory.ExecutorFactory): + + class Fine(federated_language.framework.ExecutorFactory): def create_executor(self, x): pass @@ -95,7 +96,7 @@ def _stack_fn(x): factory = ex_factory(_stack_fn) ex = factory.create_executor({}) - self.assertIsInstance(ex, executor_base.Executor) + self.assertIsInstance(ex, federated_language.framework.Executor) @parameterized.named_parameters(( 'ResourceManagingExecutorFactory', @@ -107,7 +108,7 @@ def _stack_fn(x): return ExecutorMock() factory = ex_factory(_stack_fn) - factory.clean_up_executor({placements.CLIENTS: 1}) + factory.clean_up_executor({federated_language.CLIENTS: 1}) @parameterized.named_parameters(( 'ResourceManagingExecutorFactory', @@ -146,7 +147,7 @@ def _stack_fn(x): factory = ex_factory(_stack_fn) for _ in range(2): factory.create_executor({}) - factory.create_executor({placements.SERVER: 1}) + factory.create_executor({federated_language.SERVER: 1}) self.assertEqual(num_times_invoked, 2) def test_executors_persisted_is_capped(self): @@ -156,7 +157,7 @@ def test_executors_persisted_is_capped(self): lambda _: ex ) for num_clients in range(100): - factory.create_executor({placements.CLIENTS: num_clients}) + factory.create_executor({federated_language.CLIENTS: num_clients}) self.assertLess(len(factory._executors), 20) diff --git a/tensorflow_federated/python/core/impl/executors/BUILD b/tensorflow_federated/python/core/impl/executors/BUILD index 6166c1a822..bca7088677 100644 --- a/tensorflow_federated/python/core/impl/executors/BUILD +++ b/tensorflow_federated/python/core/impl/executors/BUILD @@ -24,29 +24,6 @@ py_library( visibility = ["//tools/python_package:python_package_tool"], ) -py_library( - name = "cardinalities_utils", - srcs = ["cardinalities_utils.py"], - srcs_version = "PY3", - deps = [ - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_test( - name = "cardinalities_utils_test", - srcs = ["cardinalities_utils_test.py"], - deps = [ - ":cardinalities_utils", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - py_library( name = "cpp_to_python_executor", srcs = ["cpp_to_python_executor.py"], @@ -54,14 +31,11 @@ py_library( "nokokoro", # b/193543632: C++ execution is not fully supported in OSS. ], deps = [ - ":executor_base", ":executor_bindings", - ":executor_value_base", ":executors_errors", ":value_serialization", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/common_libs:tracing", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -76,7 +50,7 @@ py_test( ":executor_bindings", ":value_serialization", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -84,7 +58,7 @@ py_library( name = "data_conversions", srcs = ["data_conversions.py"], srcs_version = "PY3", - deps = ["//tensorflow_federated/python/core/impl/types:placements"], + deps = ["@federated_language//federated_language"], ) py_test( @@ -92,17 +66,10 @@ py_test( srcs = ["data_conversions_test.py"], deps = [ ":data_conversions", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) -py_library( - name = "executor_base", - srcs = ["executor_base.py"], - srcs_version = "PY3", - deps = [":executor_value_base"], -) - py_library( name = "executor_bindings", srcs = ["executor_bindings.py"], @@ -113,7 +80,7 @@ py_library( deps = [ ":data_conversions", "//tensorflow_federated/cc/core/impl/executors:executor_bindings", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -127,17 +94,7 @@ py_test( deps = [ ":executor_bindings", ":executor_test_utils_bindings", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "executor_factory", - srcs = ["executor_factory.py"], - srcs_version = "PY3", - deps = [ - ":executor_base", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -157,8 +114,7 @@ py_library( srcs_version = "PY3", deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:typed_object", + "@federated_language//federated_language", ], ) @@ -167,22 +123,15 @@ py_test( srcs = ["executor_utils_test.py"], deps = [ ":executor_utils", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) -py_library( - name = "executor_value_base", - srcs = ["executor_value_base.py"], - srcs_version = "PY3", - deps = ["//tensorflow_federated/python/core/impl/types:typed_object"], -) - py_library( name = "executors_errors", srcs = ["executors_errors.py"], srcs_version = "PY3", + deps = ["@federated_language//federated_language"], ) py_library( @@ -190,17 +139,13 @@ py_library( srcs = ["remote_executor.py"], srcs_version = "PY3", deps = [ - ":executor_base", - ":executor_value_base", ":executors_errors", ":remote_executor_stub", ":value_serialization", "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/common_libs:tracing", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -213,10 +158,7 @@ py_test( ":remote_executor_stub", "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -236,7 +178,7 @@ py_library( ":remote_executor_stub", "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/proto/v0:executor_py_pb2_grpc", - "//tensorflow_federated/python/common_libs:tracing", + "@federated_language//federated_language", ], ) @@ -249,8 +191,7 @@ py_test( ":value_serialization", "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/proto/v0:executor_py_pb2_grpc", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", + "@federated_language//federated_language", ], ) @@ -260,21 +201,12 @@ py_library( srcs_version = "PY3", deps = [ ":executor_utils", - "//tensorflow_federated/proto/v0:array_py_pb2", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/common_libs:tracing", - "//tensorflow_federated/python/core/impl/compiler:array", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:dtype_utils", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:array_py_pb2", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -284,13 +216,9 @@ py_test( srcs = ["value_serialization_test.py"], deps = [ ":value_serialization", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/proto/v0:executor_py_pb2", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_serialization", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) diff --git a/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py b/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py deleted file mode 100644 index d4154ee421..0000000000 --- a/tensorflow_federated/python/core/impl/executors/cardinalities_utils.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for cardinality inference and handling.""" - -from collections.abc import Callable, Mapping - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -def merge_cardinalities(existing, to_add): - """Merges dicts `existing` and `to_add`, checking for conflicts.""" - py_typecheck.check_type(existing, dict) - py_typecheck.check_type(to_add, dict) - for key, val in existing.items(): - py_typecheck.check_type(key, placements.PlacementLiteral) - py_typecheck.check_type(val, int) - if not to_add: - return existing - elif not existing: - return to_add - cardinalities = {} - cardinalities.update(existing) - for key, val in to_add.items(): - py_typecheck.check_type(key, placements.PlacementLiteral) - py_typecheck.check_type(val, int) - if key not in cardinalities: - cardinalities[key] = val - elif cardinalities[key] != val: - raise ValueError( - 'Conflicting cardinalities for {}: {} vs {}'.format( - key, val, cardinalities[key] - ) - ) - return cardinalities - - -class InvalidNonAllEqualValueError(TypeError): - - def __init__(self, value, type_spec): - message = ( - f'Expected non-all-equal value with placement {type_spec.placement} ' - 'to be a `list` or `tuple`, found a value of Python type ' - f'{type(value)}:\n{value}' - ) - super().__init__(message) - - -# We define this type here to avoid having to redeclare it wherever we -# parameterize by a cardinality inference fn. -CardinalityInferenceFnType = Callable[ - [object, computation_types.Type], Mapping[placements.PlacementLiteral, int] -] - - -def infer_cardinalities(value, type_spec): - """Infers cardinalities from Python `value`. - - Allows for any Python object to represent a federated value; enforcing - particular representations is not the job of this inference function, but - rather ingestion functions lower in the stack. - - Args: - value: Python object from which to infer TFF placement cardinalities. - type_spec: The TFF type spec for `value`, determining the semantics for - inferring cardinalities. That is, we only pull the cardinality off of - federated types. - - Returns: - Dict of cardinalities. - - Raises: - ValueError: If conflicting cardinalities are inferred from `value`. - TypeError: If the arguments are of the wrong types, or if `type_spec` is - a federated type which is not `all_equal` but the yet-to-be-embedded - `value` is not represented as a Python `list`. - """ - if value is None: - return {} - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.FederatedType): - if type_spec.all_equal: - return {} - if not isinstance(value, (list, tuple)): - raise InvalidNonAllEqualValueError(value, type_spec) - return {type_spec.placement: len(value)} - elif isinstance(type_spec, computation_types.StructType): - structure_value = structure.from_container(value, recursive=False) - cardinality_dict = {} - for idx, (_, elem_type) in enumerate(type_spec.items()): - cardinality_dict = merge_cardinalities( - cardinality_dict, infer_cardinalities(structure_value[idx], elem_type) - ) - return cardinality_dict - else: - return {} diff --git a/tensorflow_federated/python/core/impl/executors/cardinalities_utils_test.py b/tensorflow_federated/python/core/impl/executors/cardinalities_utils_test.py deleted file mode 100644 index aa6557084c..0000000000 --- a/tensorflow_federated/python/core/impl/executors/cardinalities_utils_test.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.executors import cardinalities_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -class InferCardinalitiesTest(absltest.TestCase): - - def test_returns_empty_dict_none_value(self): - type_signature = computation_types.TensorType(np.int32) - self.assertEqual( - cardinalities_utils.infer_cardinalities(None, type_signature), {} - ) - - def test_raises_none_type(self): - with self.assertRaises(TypeError): - cardinalities_utils.infer_cardinalities(1, None) - - def test_noops_on_int(self): - type_signature = computation_types.TensorType(np.int32) - cardinalities = cardinalities_utils.infer_cardinalities(1, type_signature) - self.assertEmpty(cardinalities) - - def test_raises_federated_type_integer(self): - federated_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ) - with self.assertRaises(TypeError): - cardinalities_utils.infer_cardinalities(1, federated_type) - - def test_raises_federated_type_generator(self): - def generator_fn(): - yield 1 - - generator = generator_fn() - federated_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ) - with self.assertRaises(TypeError): - cardinalities_utils.infer_cardinalities(generator, federated_type) - - def test_passes_federated_type_tuple(self): - tup = tuple(range(5)) - federated_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ) - cardinalities_utils.infer_cardinalities(tup, federated_type) - five_client_cardinalities = cardinalities_utils.infer_cardinalities( - tup, federated_type - ) - self.assertEqual(five_client_cardinalities[placements.CLIENTS], 5) - - def test_adds_list_length_as_cardinality_at_clients(self): - federated_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ) - five_clients = list(range(5)) - five_client_cardinalities = cardinalities_utils.infer_cardinalities( - five_clients, federated_type - ) - self.assertEqual(five_client_cardinalities[placements.CLIENTS], 5) - - def test_raises_conflicting_clients_sizes(self): - federated_type = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ) - five_clients = list(range(5)) - ten_clients = list(range(10)) - tuple_of_federated_types = computation_types.StructType( - [federated_type, federated_type] - ) - with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'): - cardinalities_utils.infer_cardinalities( - [five_clients, ten_clients], tuple_of_federated_types - ) - - def test_adds_list_length_as_cardinality_at_new_placement(self): - new_placement = placements.PlacementLiteral( - 'Agg', 'Agg', False, 'Intermediate aggregators' - ) - federated_type = computation_types.FederatedType( - np.int32, new_placement, all_equal=False - ) - ten_aggregators = list(range(10)) - ten_aggregator_cardinalities = cardinalities_utils.infer_cardinalities( - ten_aggregators, federated_type - ) - self.assertEqual(ten_aggregator_cardinalities[new_placement], 10) - - def test_recurses_under_tuple_type(self): - client_int = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ) - new_placement = placements.PlacementLiteral( - 'Agg', 'Agg', False, 'Intermediate aggregators' - ) - aggregator_placed_int = computation_types.FederatedType( - np.int32, new_placement, all_equal=False - ) - five_aggregators = list(range(5)) - ten_clients = list(range(10)) - mixed_cardinalities = cardinalities_utils.infer_cardinalities( - [ten_clients, five_aggregators], - computation_types.StructType([client_int, aggregator_placed_int]), - ) - self.assertEqual(mixed_cardinalities[placements.CLIENTS], 10) - self.assertEqual(mixed_cardinalities[new_placement], 5) - - def test_infer_cardinalities_success_structure(self): - foo = cardinalities_utils.infer_cardinalities( - structure.Struct([ - ('A', [1, 2, 3]), - ( - 'B', - structure.Struct([ - ('C', [[1, 2], [3, 4], [5, 6]]), - ('D', [True, False, True]), - ]), - ), - ]), - computation_types.StructType([ - ( - 'A', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ( - 'B', - [ - ( - 'C', - computation_types.FederatedType( - computation_types.SequenceType(np.int32), - placements.CLIENTS, - ), - ), - ( - 'D', - computation_types.FederatedType( - np.bool_, placements.CLIENTS - ), - ), - ], - ), - ]), - ) - self.assertDictEqual(foo, {placements.CLIENTS: 3}) - - def test_infer_cardinalities_structure_failure(self): - with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'): - cardinalities_utils.infer_cardinalities( - structure.Struct([('A', [1, 2, 3]), ('B', [1, 2])]), - computation_types.StructType([ - ( - 'A', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ( - 'B', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]), - ) - - def test_raises_invalid_non_all_equal_value_error(self): - with self.assertRaises(cardinalities_utils.InvalidNonAllEqualValueError): - cardinalities_utils.infer_cardinalities( - 5, - computation_types.FederatedType( - computation_types.TensorType(np.int32), placements.CLIENTS - ), - ) - - -class MergeCardinalitiesTest(absltest.TestCase): - - def test_raises_non_dict_arg(self): - with self.assertRaises(TypeError): - cardinalities_utils.merge_cardinalities({}, 1) - - def test_raises_non_placement_keyed_dict(self): - with self.assertRaises(TypeError): - cardinalities_utils.merge_cardinalities( - {'a': 1}, {placements.CLIENTS: 10} - ) - with self.assertRaises(TypeError): - cardinalities_utils.merge_cardinalities( - {placements.CLIENTS: 10}, {'a': 1} - ) - - def test_raises_merge_conflicting_cardinalities(self): - with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'): - cardinalities_utils.merge_cardinalities( - {placements.CLIENTS: 10}, {placements.CLIENTS: 11} - ) - - def test_noops_no_conflict(self): - clients_placed_cardinality = {placements.CLIENTS: 10} - noop = cardinalities_utils.merge_cardinalities( - clients_placed_cardinality, clients_placed_cardinality - ) - self.assertEqual(noop, clients_placed_cardinality) - - def test_merges_different_placements(self): - clients_placed_cardinality = {placements.CLIENTS: 10} - server_placed_cardinality = {placements.SERVER: 1} - merged = cardinalities_utils.merge_cardinalities( - clients_placed_cardinality, server_placed_cardinality - ) - self.assertEqual(merged, {placements.CLIENTS: 10, placements.SERVER: 1}) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py b/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py index c2ec5ed8d6..b156dc177a 100644 --- a/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py +++ b/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor.py @@ -18,14 +18,12 @@ import concurrent from typing import NoReturn, Optional +import federated_language + from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.common_libs import tracing -from tensorflow_federated.python.core.impl.executors import executor_base from tensorflow_federated.python.core.impl.executors import executor_bindings -from tensorflow_federated.python.core.impl.executors import executor_value_base from tensorflow_federated.python.core.impl.executors import executors_errors from tensorflow_federated.python.core.impl.executors import value_serialization -from tensorflow_federated.python.core.impl.types import computation_types def _handle_error(exception: Exception) -> NoReturn: @@ -35,7 +33,7 @@ def _handle_error(exception: Exception) -> NoReturn: raise exception -class CppToPythonExecutorValue(executor_value_base.ExecutorValue): +class CppToPythonExecutorValue(federated_language.framework.ExecutorValue): """ExecutorValue representation of values embedded in C++ executors. Instances of this class represent ownership of the resources which back @@ -48,7 +46,7 @@ class CppToPythonExecutorValue(executor_value_base.ExecutorValue): def __init__( self, owned_value_id: executor_bindings.OwnedValueId, - type_signature: computation_types.Type, + type_signature: federated_language.Type, cpp_executor: executor_bindings.Executor, futures_executor: concurrent.futures.Executor, ): @@ -58,14 +56,14 @@ def __init__( self._futures_executor = futures_executor @property - def type_signature(self) -> computation_types.Type: + def type_signature(self) -> federated_language.Type: return self._type_signature @property def reference(self) -> int: return self._owned_value_id.ref - @tracing.trace + @federated_language.framework.trace async def compute(self) -> object: """Pulls protocol buffer out of C++ into Python, and deserializes.""" running_loop = asyncio.get_running_loop() @@ -85,7 +83,7 @@ def _materialize(): return deserialized_value -class CppToPythonExecutorBridge(executor_base.Executor): +class CppToPythonExecutorBridge(federated_language.framework.Executor): """Implementation of Python executor interface in terms of C++ executor. This class implements a thin layer integrating the @@ -105,9 +103,9 @@ def __init__( self._cpp_executor = cpp_executor self._futures_executor = futures_executor - @tracing.trace + @federated_language.framework.trace async def create_value( - self, value: object, type_signature: computation_types.Type + self, value: object, type_signature: federated_language.Type ) -> CppToPythonExecutorValue: serialized_value, _ = value_serialization.serialize_value( value, type_signature @@ -120,7 +118,7 @@ async def create_value( owned_id, type_signature, self._cpp_executor, self._futures_executor ) - @tracing.trace + @federated_language.framework.trace async def create_call( self, fn: CppToPythonExecutorValue, @@ -142,7 +140,7 @@ async def create_call( self._futures_executor, ) - @tracing.trace + @federated_language.framework.trace async def create_struct( self, elements: Sequence[CppToPythonExecutorValue] ) -> CppToPythonExecutorValue: @@ -158,12 +156,12 @@ async def create_struct( _handle_error(e) return CppToPythonExecutorValue( struct_id, - computation_types.StructType(type_list), + federated_language.StructType(type_list), self._cpp_executor, self._futures_executor, ) - @tracing.trace + @federated_language.framework.trace async def create_selection( self, source: CppToPythonExecutorValue, index: int ) -> CppToPythonExecutorValue: diff --git a/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor_test.py b/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor_test.py index 7b420eb5e7..14773289a9 100644 --- a/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor_test.py +++ b/tensorflow_federated/python/core/impl/executors/cpp_to_python_executor_test.py @@ -16,13 +16,13 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.impl.executors import cpp_to_python_executor from tensorflow_federated.python.core.impl.executors import executor_bindings from tensorflow_federated.python.core.impl.executors import value_serialization -from tensorflow_federated.python.core.impl.types import computation_types class CcToPythonExecutorTest( @@ -40,17 +40,17 @@ def setUp(self): ) @parameterized.named_parameters( - ('integer', 1, computation_types.TensorType(np.int32)), - ('float', 1.0, computation_types.TensorType(np.float32)), + ('integer', 1, federated_language.TensorType(np.int32)), + ('float', 1.0, federated_language.TensorType(np.float32)), ( 'mixed_structure', structure.Struct.unnamed(0, 1.0), - computation_types.StructType([np.int32, np.float32]), + federated_language.StructType([np.int32, np.float32]), ), ( 'nested_structure', structure.Struct.unnamed(0, structure.Struct.unnamed(1, 2)), - computation_types.StructType([np.int32, [np.int32, np.int32]]), + federated_language.StructType([np.int32, [np.int32, np.int32]]), ), ) async def test_create_value(self, value, type_spec): @@ -66,8 +66,8 @@ async def test_create_call_tensorflow_function_noarg(self): fn = unittest.mock.create_autospec( cpp_to_python_executor.CppToPythonExecutorValue ) - fn.type_signature = computation_types.FunctionType( - None, computation_types.TensorType(np.int32) + fn.type_signature = federated_language.FunctionType( + None, federated_language.TensorType(np.int32) ) fn.reference = 1 @@ -84,12 +84,12 @@ async def test_create_call_tensorflow_function_with_arg(self): fn = unittest.mock.create_autospec( cpp_to_python_executor.CppToPythonExecutorValue ) - fn.type_signature = computation_types.FunctionType(None, np.int32) + fn.type_signature = federated_language.FunctionType(None, np.int32) fn.reference = 1 arg = unittest.mock.create_autospec( cpp_to_python_executor.CppToPythonExecutorValue ) - arg.type_signature = computation_types.TensorType(np.int32) + arg.type_signature = federated_language.TensorType(np.int32) arg.reference = 2 constructed_call = await self._test_executor.create_call(fn, arg) @@ -105,7 +105,7 @@ async def test_create_struct(self): struct_element = unittest.mock.create_autospec( cpp_to_python_executor.CppToPythonExecutorValue ) - struct_element.type_signature = computation_types.TensorType(np.int32) + struct_element.type_signature = federated_language.TensorType(np.int32) struct_element.reference = 1 constructed_struct = await self._test_executor.create_struct( @@ -123,7 +123,7 @@ async def test_create_selection(self): source = unittest.mock.create_autospec( cpp_to_python_executor.CppToPythonExecutorValue ) - source.type_signature = computation_types.StructType([np.int32]) + source.type_signature = federated_language.StructType([np.int32]) source.reference = 1 selected_element = await self._test_executor.create_selection(source, 0) @@ -135,10 +135,10 @@ async def test_compute(self): owned_id = unittest.mock.create_autospec(executor_bindings.OwnedValueId) owned_id.ref = 1 serialized_two, _ = value_serialization.serialize_value( - 2, computation_types.TensorType(np.int32) + 2, federated_language.TensorType(np.int32) ) self._mock_executor.materialize.return_value = serialized_two - type_signature = computation_types.TensorType(np.int32) + type_signature = federated_language.TensorType(np.int32) executor_value = cpp_to_python_executor.CppToPythonExecutorValue( owned_id, type_signature, diff --git a/tensorflow_federated/python/core/impl/executors/data_conversions.py b/tensorflow_federated/python/core/impl/executors/data_conversions.py index 910994ebe9..b9de95d9f2 100644 --- a/tensorflow_federated/python/core/impl/executors/data_conversions.py +++ b/tensorflow_federated/python/core/impl/executors/data_conversions.py @@ -15,11 +15,11 @@ from collections.abc import Mapping -from tensorflow_federated.python.core.impl.types import placements +import federated_language def convert_cardinalities_dict_to_string_keyed( - cardinalities: Mapping[placements.PlacementLiteral, int] + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> Mapping[str, int]: """Ensures incoming cardinalities dict is formatted correctly.""" if not isinstance(cardinalities, Mapping): @@ -29,7 +29,7 @@ def convert_cardinalities_dict_to_string_keyed( ) uri_cardinalities = {} for placement, cardinality in cardinalities.items(): - if not isinstance(placement, placements.PlacementLiteral): + if not isinstance(placement, federated_language.framework.PlacementLiteral): raise TypeError( '`cardinalities` must be a `Mapping` with ' '`PlacementLiteral` (e.g. `tff.CLIENTS`) keys. ' diff --git a/tensorflow_federated/python/core/impl/executors/data_conversions_test.py b/tensorflow_federated/python/core/impl/executors/data_conversions_test.py index 0979f5b443..c444de5cdc 100644 --- a/tensorflow_federated/python/core/impl/executors/data_conversions_test.py +++ b/tensorflow_federated/python/core/impl/executors/data_conversions_test.py @@ -13,9 +13,9 @@ # limitations under the License. from absl.testing import absltest +import federated_language from tensorflow_federated.python.core.impl.executors import data_conversions -from tensorflow_federated.python.core.impl.types import placements class DataConversionsTest(absltest.TestCase): @@ -23,12 +23,12 @@ class DataConversionsTest(absltest.TestCase): def test_converts_placement_keyed_to_string_keyed(self): num_clients = 10 placement_keyed_mapping = { - placements.SERVER: 1, - placements.CLIENTS: num_clients, + federated_language.SERVER: 1, + federated_language.CLIENTS: num_clients, } expected_string_keyed_mapping = { - placements.SERVER.uri: 1, - placements.CLIENTS.uri: num_clients, + federated_language.SERVER.uri: 1, + federated_language.CLIENTS.uri: num_clients, } string_keyed_mapping = ( @@ -40,7 +40,10 @@ def test_converts_placement_keyed_to_string_keyed(self): self.assertEqual(string_keyed_mapping, expected_string_keyed_mapping) def test_raises_string_keyed_mapping(self): - string_keyed_mapping = {placements.SERVER.uri: 1, placements.CLIENTS.uri: 5} + string_keyed_mapping = { + federated_language.SERVER.uri: 1, + federated_language.CLIENTS.uri: 5, + } with self.assertRaises(TypeError): data_conversions.convert_cardinalities_dict_to_string_keyed( @@ -49,8 +52,8 @@ def test_raises_string_keyed_mapping(self): def test_raises_non_integer_values(self): placement_keyed_non_integer_valued_mapping = { - placements.SERVER: 1.0, - placements.CLIENTS: 10.0, + federated_language.SERVER: 1.0, + federated_language.CLIENTS: 10.0, } with self.assertRaises(TypeError): diff --git a/tensorflow_federated/python/core/impl/executors/executor_base.py b/tensorflow_federated/python/core/impl/executors/executor_base.py deleted file mode 100644 index 860366a283..0000000000 --- a/tensorflow_federated/python/core/impl/executors/executor_base.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A base Python interface for all types of executors.""" - -import abc -from typing import Optional - -from tensorflow_federated.python.core.impl.executors import executor_value_base as evb - - -class Executor(metaclass=abc.ABCMeta): - """Represents the abstract interface that all executors must implement.""" - - # TODO: b/134543154 - Migrate the reference executor over this new interface. - - # TODO: b/134543154 - Standardize and document the kinds of values that can be - # embedded and must be understood by all executor implementations, possibly - # factoring out parts of reference executor's `to_representation_for_type()`. - - @abc.abstractmethod - def close(self): - """Release resources associated with this Executor, if any. - - If the executor has one or more target Executors, implementation of this - method must close them. - """ - raise NotImplementedError - - @abc.abstractmethod - async def create_value(self, value, type_spec=None) -> evb.ExecutorValue: - """A coroutine that creates embedded value from `value` of type `type_spec`. - - This function is used to embed a value within the executor. The argument - can be one of the plain Python types, a nested structure, a representation - of a TFF computation, etc. Once embedded, the value can be further passed - around within the executor. For functional values, embedding them prior to - invocation potentially allows the executor to amortize overhead across - multiple calls. - - Args: - value: An object that represents the value to embed within the executor. - type_spec: An optional `tff.Type` of the value represented by this object, - or something convertible to it. The type can only be omitted if the - value is a instance of `tff.TypedObject`. - - Returns: - An instance of `ExecutorValue` that represents the embedded value. - """ - raise NotImplementedError - - @abc.abstractmethod - async def create_call( - self, comp: evb.ExecutorValue, arg: Optional[evb.ExecutorValue] = None - ) -> evb.ExecutorValue: - """A coroutine that creates a call to `comp` with optional argument `arg`. - - Args: - comp: The computation to invoke. It must have been first embedded in the - executor by calling `create_value()` on it first. - arg: An optional argument of the call, or `None` if no argument was - supplied. If it is present, it must have been embedded in the executor - by calling `create_value()` on it first. - - Returns: - An instance of `ExecutorValue` that represents the constructed call. - """ - raise NotImplementedError - - @abc.abstractmethod - async def create_struct(self, elements) -> evb.ExecutorValue: - """A coroutine that creates a tuple of `elements`. - - Args: - elements: A collection of `ExecutorValue`s to create a tuple from. The - collection may be of any kind accepted by `structure.from_container`, - including dictionaries and lists. The `ExecutorValues` in the container - must have been created by calling `create_value` on this executor. - - Returns: - An instance of `ExecutorValue` that represents the constructed tuple. - """ - raise NotImplementedError - - @abc.abstractmethod - async def create_selection(self, source, index) -> evb.ExecutorValue: - """A coroutine that creates a selection from `source`. - - Args: - source: The source to select from. The source must have been embedded in - this executor by invoking `create_value()` on it first. - index: An integer index to select. - - Returns: - An instance of `ExecutorValue` that represents the constructed selection. - """ - raise NotImplementedError diff --git a/tensorflow_federated/python/core/impl/executors/executor_bindings.py b/tensorflow_federated/python/core/impl/executors/executor_bindings.py index a433212ace..842b30d8ab 100644 --- a/tensorflow_federated/python/core/impl/executors/executor_bindings.py +++ b/tensorflow_federated/python/core/impl/executors/executor_bindings.py @@ -15,9 +15,10 @@ from collections.abc import Mapping +import federated_language + from tensorflow_federated.cc.core.impl.executors import executor_bindings from tensorflow_federated.python.core.impl.executors import data_conversions -from tensorflow_federated.python.core.impl.types import placements # Import classes. OwnedValueId = executor_bindings.OwnedValueId @@ -40,7 +41,7 @@ def create_federating_executor( inner_server_executor: executor_bindings.Executor, inner_client_executor: executor_bindings.Executor, - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> executor_bindings.Executor: """Constructs a FederatingExecutor with a specified placement.""" uri_cardinalities = ( @@ -53,7 +54,7 @@ def create_federating_executor( def create_remote_executor( channel: GRPCChannel, - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> executor_bindings.Executor: """Constructs a RemoteExecutor proxying service on `channel`.""" uri_cardinalities = ( @@ -64,7 +65,7 @@ def create_remote_executor( def create_streaming_remote_executor( channel: GRPCChannel, - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> executor_bindings.Executor: """Constructs a StreamingRemoteExecutor proxying service on `channel`.""" uri_cardinalities = ( @@ -77,7 +78,7 @@ def create_streaming_remote_executor( def create_composing_child( executor: executor_bindings.Executor, - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> executor_bindings.Executor: """Constructs a ComposingChild with specified cardinalities.""" uri_cardinalities = ( diff --git a/tensorflow_federated/python/core/impl/executors/executor_bindings_test.py b/tensorflow_federated/python/core/impl/executors/executor_bindings_test.py index 16663d2856..cd7b7eeec4 100644 --- a/tensorflow_federated/python/core/impl/executors/executor_bindings_test.py +++ b/tensorflow_federated/python/core/impl/executors/executor_bindings_test.py @@ -13,11 +13,11 @@ # limitations under the License. from absl.testing import absltest +import federated_language import portpicker from tensorflow_federated.python.core.impl.executors import executor_bindings from tensorflow_federated.python.core.impl.executors import executor_test_utils_bindings -from tensorflow_federated.python.core.impl.types import placements class ReferenceResolvingExecutorBindingsTest(absltest.TestCase): @@ -38,7 +38,7 @@ def test_construction(self): children = [ executor_bindings.create_composing_child( mock_child_executor, - {placements.CLIENTS: 0}, + {federated_language.CLIENTS: 0}, ) ] try: @@ -64,7 +64,7 @@ class FederatingExecutorBindingsTest(absltest.TestCase): def test_construction(self): mock_server_executor = executor_test_utils_bindings.create_mock_executor() mock_client_executor = executor_test_utils_bindings.create_mock_executor() - cardinalities = {placements.CLIENTS: 0} + cardinalities = {federated_language.CLIENTS: 0} try: executor_bindings.create_federating_executor( @@ -83,7 +83,7 @@ def test_construction_with_insecure_channel(self): try: executor_bindings.create_remote_executor( channel, - cardinalities={placements.CLIENTS: 10}, + cardinalities={federated_language.CLIENTS: 10}, ) except Exception: # pylint: disable=broad-except self.fail('Raised `Exception` unexpectedly.') diff --git a/tensorflow_federated/python/core/impl/executors/executor_factory.py b/tensorflow_federated/python/core/impl/executors/executor_factory.py deleted file mode 100644 index 9d9fba40d3..0000000000 --- a/tensorflow_federated/python/core/impl/executors/executor_factory.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ExecutorFactory interface and simple implementation.""" - -import abc -from collections.abc import MutableMapping - -from tensorflow_federated.python.core.impl.executors import executor_base -from tensorflow_federated.python.core.impl.types import placements - -CardinalitiesType = MutableMapping[placements.PlacementLiteral, int] - - -class ExecutorFactory(metaclass=abc.ABCMeta): - """Interface defining executor factories. - - `ExecutorFactory` should be considered to own the executors it creates; it - is responsible for their instantiation and management. - - `ExecutorFactory` exposes two methods, `create_executor` and - `clean_up_executors`. There is a particular coupling between these two - methods; any executor returned by `create_executor` should not be used - after `clean_up_executors` has been called without reinitialization. That is, - `create_executor` should be called again, and `ExecutorFactory` will ensure - that the returned executor is safe for use. - """ - - @abc.abstractmethod - def create_executor( - self, cardinalities: CardinalitiesType - ) -> executor_base.Executor: - """Abstract method to construct instance of `executor_base.Executor`. - - `create_executor` must accept a dict mapping - `placements.PlacementLiterals` to `ints`, and return an - `executor_base.Executor`. - - Args: - cardinalities: a dict mapping instances of `placements.PlacementLiteral` - to ints, specifying the population size at each placement. - - Returns: - Instance of `executor_base.Executor`. - """ - pass - - @abc.abstractmethod - def clean_up_executor(self, cardinalities: CardinalitiesType): - """Releases any resources associated to the given cardinalities. - - Note that calling this method may invalidate the state of any executors - which have previously been returned by the factory with the `cardinalities` - argument ; `create_executor` should be called again if a new executor which - is safe to use is desired. - - Args: - cardinalities: The cardinalities of the executor whose state we wish to - clear. - """ - pass diff --git a/tensorflow_federated/python/core/impl/executors/executor_utils.py b/tensorflow_federated/python/core/impl/executors/executor_utils.py index 49e4fa89da..96a1f90a1f 100644 --- a/tensorflow_federated/python/core/impl/executors/executor_utils.py +++ b/tensorflow_federated/python/core/impl/executors/executor_utils.py @@ -15,14 +15,14 @@ from typing import Optional +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import typed_object def reconcile_value_with_type_spec( - value: object, type_spec: computation_types.Type -) -> computation_types.Type: + value: object, type_spec: federated_language.Type +) -> federated_language.Type: """Reconciles the type of `value` with the given `type_spec`. The currently implemented logic only performs reconciliation of `value` and @@ -48,7 +48,7 @@ def reconcile_value_with_type_spec( TypeError: If the `value` type and `type_spec` are incompatible, or if the type cannot be determined.. """ - if isinstance(value, typed_object.TypedObject): + if isinstance(value, federated_language.TypedObject): return reconcile_value_type_with_type_spec(value.type_signature, type_spec) elif type_spec is not None: return type_spec @@ -59,9 +59,9 @@ def reconcile_value_with_type_spec( def reconcile_value_type_with_type_spec( - value_type: computation_types.Type, - type_spec: Optional[computation_types.Type], -) -> computation_types.Type: + value_type: federated_language.Type, + type_spec: Optional[federated_language.Type], +) -> federated_language.Type: """Reconciles a pair of types. Args: @@ -75,9 +75,9 @@ def reconcile_value_type_with_type_spec( Raises: TypeError: If arguments are of incompatible types. """ - py_typecheck.check_type(value_type, computation_types.Type) + py_typecheck.check_type(value_type, federated_language.Type) if type_spec is not None: - py_typecheck.check_type(value_type, computation_types.Type) + py_typecheck.check_type(value_type, federated_language.Type) if not value_type.is_equivalent_to(type_spec): raise TypeError( 'Expected a value of type {}, found {}.'.format(type_spec, value_type) diff --git a/tensorflow_federated/python/core/impl/executors/executor_utils_test.py b/tensorflow_federated/python/core/impl/executors/executor_utils_test.py index 31f6e58d58..2f97ea9149 100644 --- a/tensorflow_federated/python/core/impl/executors/executor_utils_test.py +++ b/tensorflow_federated/python/core/impl/executors/executor_utils_test.py @@ -14,11 +14,9 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np - -from tensorflow_federated.python.core.impl.compiler import building_block_factory from tensorflow_federated.python.core.impl.executors import executor_utils -from tensorflow_federated.python.core.impl.types import computation_types class TypeUtilsTest(parameterized.TestCase): @@ -26,25 +24,25 @@ class TypeUtilsTest(parameterized.TestCase): @parameterized.named_parameters([ ( 'buiding_block_and_type_spec', - building_block_factory.create_identity( - computation_types.TensorType(np.int32) + federated_language.framework.create_identity( + federated_language.TensorType(np.int32) ), - computation_types.FunctionType(np.int32, np.int32), - computation_types.FunctionType(np.int32, np.int32), + federated_language.FunctionType(np.int32, np.int32), + federated_language.FunctionType(np.int32, np.int32), ), ( 'buiding_block_and_none', - building_block_factory.create_identity( - computation_types.TensorType(np.int32) + federated_language.framework.create_identity( + federated_language.TensorType(np.int32) ), None, - computation_types.FunctionType(np.int32, np.int32), + federated_language.FunctionType(np.int32, np.int32), ), ( 'int_and_type_spec', 10, - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ]) def test_reconcile_value_with_type_spec_returns_type( @@ -58,10 +56,10 @@ def test_reconcile_value_with_type_spec_returns_type( @parameterized.named_parameters([ ( 'building_block_and_bad_type_spec', - building_block_factory.create_identity( - computation_types.TensorType(np.int32) + federated_language.framework.create_identity( + federated_language.TensorType(np.int32) ), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ('int_and_none', 10, None), ]) @@ -74,15 +72,15 @@ def test_reconcile_value_with_type_spec_raises_type_error( @parameterized.named_parameters([ ( 'value_type_and_type_spec', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ( 'value_type_and_none', - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), None, - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ), ]) def test_reconcile_value_type_with_type_spec_returns_type( @@ -96,8 +94,8 @@ def test_reconcile_value_type_with_type_spec_returns_type( def test_reconcile_value_type_with_type_spec_raises_type_error_value_type_and_bad_type_spec( self, ): - value_type = computation_types.TensorType(np.int32) - type_spec = computation_types.TensorType(np.str_) + value_type = federated_language.TensorType(np.int32) + type_spec = federated_language.TensorType(np.str_) with self.assertRaises(TypeError): executor_utils.reconcile_value_type_with_type_spec(value_type, type_spec) diff --git a/tensorflow_federated/python/core/impl/executors/executor_value_base.py b/tensorflow_federated/python/core/impl/executors/executor_value_base.py deleted file mode 100644 index 66ffb63130..0000000000 --- a/tensorflow_federated/python/core/impl/executors/executor_value_base.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A base Python interface for values embedded in executors.""" - -import abc - -from tensorflow_federated.python.core.impl.types import typed_object - - -class ExecutorValue(abc.ABC, typed_object.TypedObject): - """Represents the abstract interface for values embedded within executors. - - The embedded values may represent computations in-flight that may materialize - in the future or fail before they materialize. - """ - - @property - @abc.abstractmethod - def reference(self): - """Returns a reference to the value without transferring ownership. - - A reference is an opaque object that is understood by the executors that - produced the value, therefore: - - 1. Executors need to preserve this contract in their implementation. - 2. Users of Executors should not need to depend on this value. - """ - raise NotImplementedError - - @abc.abstractmethod - async def compute(self): - """A coroutine that asynchronously returns the computed form of the value. - - The computed form of a value can take a number of forms, such as primitive - types in Python, numpy arrays, or even eager tensors in case this is an - eager executor, or an executor backed by an eager one. - - Returns: - The computed form of the value, as defined above. - """ - raise NotImplementedError diff --git a/tensorflow_federated/python/core/impl/executors/executors_errors.py b/tensorflow_federated/python/core/impl/executors/executors_errors.py index 1f9c030383..76d7d056e5 100644 --- a/tensorflow_federated/python/core/impl/executors/executors_errors.py +++ b/tensorflow_federated/python/core/impl/executors/executors_errors.py @@ -16,15 +16,14 @@ import typing from typing import Union +import federated_language import grpc from typing_extensions import TypeGuard -class RetryableError(Exception): - """Raised when execution fails and can be retried.""" - - -class RetryableGRPCError(RetryableError, grpc.RpcError, grpc.Call): +class RetryableGRPCError( + federated_language.framework.RetryableError, grpc.RpcError, grpc.Call +): """Raised when execution fails across a gRPC connection and can be retried.""" def __init__(self, call: grpc.Call): @@ -51,7 +50,7 @@ def get_grpc_retryable_error_codes() -> set[grpc.StatusCode]: ]) -class RetryableAbslStatusError(RetryableError): +class RetryableAbslStatusError(federated_language.framework.RetryableError): """Raised when execution fails with an absl status error and can be retried.""" diff --git a/tensorflow_federated/python/core/impl/executors/remote_executor.py b/tensorflow_federated/python/core/impl/executors/remote_executor.py index 97224d9ce3..779464ae2f 100644 --- a/tensorflow_federated/python/core/impl/executors/remote_executor.py +++ b/tensorflow_federated/python/core/impl/executors/remote_executor.py @@ -18,22 +18,18 @@ import weakref from absl import logging +import federated_language import grpc from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.common_libs import tracing -from tensorflow_federated.python.core.impl.executors import executor_base -from tensorflow_federated.python.core.impl.executors import executor_value_base from tensorflow_federated.python.core.impl.executors import executors_errors from tensorflow_federated.python.core.impl.executors import remote_executor_stub from tensorflow_federated.python.core.impl.executors import value_serialization -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -class RemoteValue(executor_value_base.ExecutorValue): +class RemoteValue(federated_language.framework.ExecutorValue): """A reference to a value embedded in a remotely deployed executor service.""" def __init__( @@ -48,13 +44,13 @@ def __init__( Args: value_ref: An instance of `executor_pb2.ValueRef` returned by the remote executor service. - type_spec: An instance of `computation_types.Type`. + type_spec: An instance of `federated_language.Type`. executor: The executor that created this value. dispose_at_exit: The flag to disable calling dispose on the object at deletion. """ py_typecheck.check_type(value_ref, executor_pb2.ValueRef) - py_typecheck.check_type(type_spec, computation_types.Type) + py_typecheck.check_type(type_spec, federated_language.Type) py_typecheck.check_type(executor, RemoteExecutor) self._value_ref = value_ref self._type_signature = type_spec @@ -78,12 +74,12 @@ def type_signature(self): def reference(self): return self._value_ref - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def compute(self): return await self._executor._compute(self._value_ref, self._type_signature) # pylint: disable=protected-access -class RemoteExecutor(executor_base.Executor): +class RemoteExecutor(federated_language.framework.Executor): """The remote executor is a local proxy for a remote executor instance.""" # TODO: b/134543154 - Switch to using an asynchronous gRPC client so we don't @@ -152,9 +148,12 @@ def _dispose( ) self._stub.dispose(dispose_request) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) def set_cardinalities( - self, cardinalities: Mapping[placements.PlacementLiteral, int] + self, + cardinalities: Mapping[ + federated_language.framework.PlacementLiteral, int + ], ): if self._executor_id is not None: self._clear_executor() @@ -169,14 +168,14 @@ def set_cardinalities( executor=self._executor_id ) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) def _clear_executor(self): if self._executor_id is None: return request = executor_pb2.DisposeExecutorRequest(executor=self._executor_id) try: self._stub.dispose_executor(request) - except (grpc.RpcError, executors_errors.RetryableError): + except (grpc.RpcError, federated_language.framework.RetryableError): logging.debug( 'RPC error caught during attempt to clear state on the ' 'server; this likely indicates a broken connection, and ' @@ -186,9 +185,9 @@ def _clear_executor(self): self._dispose_request = None return - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def create_value_stream_structs( - self, value, type_spec: computation_types.StructType + self, value, type_spec: federated_language.StructType ): value = structure.from_container(value) if len(value) != len(type_spec): @@ -213,16 +212,16 @@ async def create_value_stream_structs( value_refs = await asyncio.gather(*value_refs) return await self.create_struct(value_refs) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def create_value(self, value, type_spec=None): self._check_has_executor_id() - @tracing.trace + @federated_language.framework.trace def serialize_value(): return value_serialization.serialize_value(value, type_spec) if self._stream_structs and isinstance( - type_spec, computation_types.StructType + type_spec, federated_language.StructType ): return await self.create_value_stream_structs(value, type_spec) @@ -234,11 +233,13 @@ def serialize_value(): py_typecheck.check_type(response, executor_pb2.CreateValueResponse) return RemoteValue(response.value_ref, type_spec, self) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def create_call(self, comp, arg=None): self._check_has_executor_id() py_typecheck.check_type(comp, RemoteValue) - py_typecheck.check_type(comp.type_signature, computation_types.FunctionType) + py_typecheck.check_type( + comp.type_signature, federated_language.FunctionType + ) if arg is not None: py_typecheck.check_type(arg, RemoteValue) create_call_request = executor_pb2.CreateCallRequest( @@ -250,7 +251,7 @@ async def create_call(self, comp, arg=None): py_typecheck.check_type(response, executor_pb2.CreateCallResponse) return RemoteValue(response.value_ref, comp.type_signature.result, self) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def create_struct(self, elements): self._check_has_executor_id() constructed_anon_tuple = structure.from_container(elements) @@ -264,7 +265,7 @@ async def create_struct(self, elements): ) ) type_elem.append((k, v.type_signature) if k else v.type_signature) - result_type = computation_types.StructType(type_elem) + result_type = federated_language.StructType(type_elem) request = executor_pb2.CreateStructRequest( executor=self._executor_id, element=proto_elem ) @@ -272,11 +273,13 @@ async def create_struct(self, elements): py_typecheck.check_type(response, executor_pb2.CreateStructResponse) return RemoteValue(response.value_ref, result_type, self) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def create_selection(self, source, index): self._check_has_executor_id() py_typecheck.check_type(source, RemoteValue) - py_typecheck.check_type(source.type_signature, computation_types.StructType) + py_typecheck.check_type( + source.type_signature, federated_language.StructType + ) py_typecheck.check_type(index, int) result_type = source.type_signature[index] request = executor_pb2.CreateSelectionRequest( @@ -286,9 +289,9 @@ async def create_selection(self, source, index): py_typecheck.check_type(response, executor_pb2.CreateSelectionResponse) return RemoteValue(response.value_ref, result_type, self) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def _compute_stream_structs( - self, value_ref, type_spec: computation_types.StructType + self, value_ref, type_spec: federated_language.StructType ): py_typecheck.check_type(value_ref, executor_pb2.ValueRef) values = [] @@ -310,13 +313,13 @@ async def per_element(source, index, element_spec): zip(structure.name_list_with_nones(type_spec), values) ) - @tracing.trace(span=True) + @federated_language.framework.trace(span=True) async def _compute(self, value_ref, type_spec): self._check_has_executor_id() py_typecheck.check_type(value_ref, executor_pb2.ValueRef) if self._stream_structs and isinstance( - type_spec, computation_types.StructType + type_spec, federated_language.StructType ): return await self._compute_stream_structs(value_ref, type_spec) diff --git a/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub.py b/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub.py index e9ebd66f10..394b427d5f 100644 --- a/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub.py +++ b/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub.py @@ -14,19 +14,19 @@ """A stub connects to a remote executor over gRPC.""" from absl import logging +import federated_language import grpc from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.proto.v0 import executor_pb2_grpc -from tensorflow_federated.python.common_libs import tracing from tensorflow_federated.python.core.impl.executors import executors_errors from tensorflow_federated.python.core.impl.executors import remote_executor_stub -@tracing.trace(span=True) +@federated_language.framework.trace(span=True) def _request(rpc_func, request): """Populates trace context and reraises gRPC errors with retryable info.""" - with tracing.wrap_rpc_in_trace_context(): + with federated_language.framework.wrap_rpc_in_trace_context(): try: return rpc_func(request) except grpc.RpcError as e: diff --git a/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub_test.py b/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub_test.py index 5d36ba0323..e334ec077f 100644 --- a/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub_test.py +++ b/tensorflow_federated/python/core/impl/executors/remote_executor_grpc_stub_test.py @@ -15,16 +15,15 @@ from unittest import mock from absl.testing import absltest +import federated_language import grpc import portpicker from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.proto.v0 import executor_pb2_grpc -from tensorflow_federated.python.core.impl.computation import computation_impl from tensorflow_federated.python.core.impl.executors import executors_errors from tensorflow_federated.python.core.impl.executors import remote_executor_grpc_stub from tensorflow_federated.python.core.impl.executors import value_serialization -from tensorflow_federated.python.core.impl.federated_context import federated_computation def create_stub(): @@ -51,7 +50,7 @@ def trailing_metadata(self): raise NotImplementedError() -@federated_computation.federated_computation() +@federated_language.federated_computation() def _empty_struct(): return () @@ -73,7 +72,9 @@ def test_grpc_connectivity(self): class RemoteExecutorGrpcStubTest(absltest.TestCase): def test_compute_returns_result(self, mock_executor_grpc_stub): - proto = computation_impl.ConcreteComputation.get_proto(_empty_struct) + proto = federated_language.framework.ConcreteComputation.get_proto( + _empty_struct + ) value = executor_pb2.Value(computation=proto) response = executor_pb2.ComputeResponse(value=value) instance = mock_executor_grpc_stub.return_value @@ -100,7 +101,7 @@ def test_compute_raises_retryable_error_on_grpc_error_unavailable( ) stub = create_stub() - with self.assertRaises(executors_errors.RetryableError): + with self.assertRaises(federated_language.framework.RetryableError): stub.compute( executor_pb2.ComputeRequest(value_ref=executor_pb2.ValueRef()) ) @@ -146,7 +147,7 @@ def test_create_value_raises_retryable_error_on_grpc_error_unavailable( ) stub = create_stub() - with self.assertRaises(executors_errors.RetryableError): + with self.assertRaises(federated_language.framework.RetryableError): stub.create_value(request=executor_pb2.CreateValueRequest()) def test_create_value_reraises_grpc_error(self, mock_executor_grpc_stub): @@ -188,7 +189,7 @@ def test_create_call_raises_retryable_error_on_grpc_error_unavailable( ) stub = create_stub() - with self.assertRaises(executors_errors.RetryableError): + with self.assertRaises(federated_language.framework.RetryableError): stub.create_call(request=executor_pb2.CreateCallRequest()) def test_create_call_reraises_grpc_error(self, mock_executor_grpc_stub): @@ -231,7 +232,7 @@ def test_create_struct_raises_retryable_error_on_grpc_error_unavailable( ) stub = create_stub() - with self.assertRaises(executors_errors.RetryableError): + with self.assertRaises(federated_language.framework.RetryableError): stub.create_struct(request=executor_pb2.CreateStructRequest()) def test_create_struct_reraises_grpc_error(self, mock_executor_grpc_stub): @@ -276,7 +277,7 @@ def test_create_selection_raises_retryable_error_on_grpc_error_unavailable( ) stub = create_stub() - with self.assertRaises(executors_errors.RetryableError): + with self.assertRaises(federated_language.framework.RetryableError): stub.create_selection(request=executor_pb2.CreateSelectionRequest()) def test_create_selection_reraises_non_retryable_grpc_error( diff --git a/tensorflow_federated/python/core/impl/executors/remote_executor_test.py b/tensorflow_federated/python/core/impl/executors/remote_executor_test.py index e948caa454..1c2964f420 100644 --- a/tensorflow_federated/python/core/impl/executors/remote_executor_test.py +++ b/tensorflow_federated/python/core/impl/executors/remote_executor_test.py @@ -17,17 +17,14 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import grpc import numpy as np from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_impl from tensorflow_federated.python.core.impl.executors import remote_executor from tensorflow_federated.python.core.impl.executors import remote_executor_stub -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements def _raise_non_retryable_grpc_error(*args): @@ -43,10 +40,10 @@ def _set_cardinalities_with_mock( mock_stub.get_executor.return_value = executor_pb2.GetExecutorResponse( executor=executor_pb2.ExecutorId(id='id') ) - executor.set_cardinalities({placements.CLIENTS: 3}) + executor.set_cardinalities({federated_language.CLIENTS: 3}) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _empty_struct(): return () @@ -61,15 +58,17 @@ class RemoteValueTest(parameterized.TestCase): def test_compute_returns_result_with_stream_structs( self, stream_structs, mock_stub ): - proto = computation_impl.ConcreteComputation.get_proto(_empty_struct) + proto = federated_language.framework.ConcreteComputation.get_proto( + _empty_struct + ) value = executor_pb2.Value(computation=proto) mock_stub.compute.return_value = executor_pb2.ComputeResponse(value=value) executor = remote_executor.RemoteExecutor( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - executor.set_cardinalities({placements.CLIENTS: 3}) - type_signature = computation_types.FunctionType(None, np.int32) + executor.set_cardinalities({federated_language.CLIENTS: 3}) + type_signature = federated_language.FunctionType(None, np.int32) remote_value = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -91,7 +90,7 @@ def test_compute_reraises_grpc_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.FunctionType(None, np.int32) + type_signature = federated_language.FunctionType(None, np.int32) comp = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -114,7 +113,7 @@ def test_compute_reraises_type_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.FunctionType(None, np.int32) + type_signature = federated_language.FunctionType(None, np.int32) comp = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -140,7 +139,7 @@ def test_set_cardinalities_returns_none_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - result = executor.set_cardinalities({placements.CLIENTS: 3}) + result = executor.set_cardinalities({federated_language.CLIENTS: 3}) self.assertIsNone(result) @parameterized.named_parameters( @@ -158,7 +157,7 @@ def test_create_value_returns_remote_value_with_stream_structs( _set_cardinalities_with_mock(executor, mock_stub) result = asyncio.run( - executor.create_value(1, computation_types.TensorType(np.int32)) + executor.create_value(1, federated_language.TensorType(np.int32)) ) mock_stub.create_value.assert_called_once() @@ -182,7 +181,7 @@ def test_create_value_reraises_grpc_error_with_stream_structs( with self.assertRaises(grpc.RpcError) as context: asyncio.run( - executor.create_value(1, computation_types.TensorType(np.int32)) + executor.create_value(1, federated_language.TensorType(np.int32)) ) self.assertEqual(context.exception.code(), grpc.StatusCode.ABORTED) @@ -203,7 +202,7 @@ def test_create_value_reraises_type_error_with_stream_structs( with self.assertRaises(TypeError): asyncio.run( - executor.create_value(1, computation_types.TensorType(np.int32)) + executor.create_value(1, federated_language.TensorType(np.int32)) ) @parameterized.named_parameters( @@ -239,23 +238,23 @@ def test_create_value_for_nested_struct_with_stream_structs( ('c', np.zeros(shape=tensor_shape, dtype=np.int32)), ]) - type_signature = computation_types.StructType([ + type_signature = federated_language.StructType([ ( 'a', - computation_types.TensorType(shape=tensor_shape, dtype=np.int32), + federated_language.TensorType(shape=tensor_shape, dtype=np.int32), ), ( 'b', - computation_types.StructType([ + federated_language.StructType([ ( 'b0', - computation_types.TensorType( + federated_language.TensorType( shape=tensor_shape, dtype=np.int32 ), ), ( 'b1', - computation_types.TensorType( + federated_language.TensorType( shape=tensor_shape, dtype=np.int32 ), ), @@ -263,7 +262,7 @@ def test_create_value_for_nested_struct_with_stream_structs( ), ( 'c', - computation_types.TensorType(shape=tensor_shape, dtype=np.int32), + federated_language.TensorType(shape=tensor_shape, dtype=np.int32), ), ]) @@ -288,7 +287,7 @@ def test_create_call_returns_remote_value_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.FunctionType(None, np.int32) + type_signature = federated_language.FunctionType(None, np.int32) fn = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -313,7 +312,7 @@ def test_create_call_reraises_grpc_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.FunctionType(None, np.int32) + type_signature = federated_language.FunctionType(None, np.int32) comp = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -336,7 +335,7 @@ def test_create_call_reraises_type_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.FunctionType(None, np.int32) + type_signature = federated_language.FunctionType(None, np.int32) comp = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -357,7 +356,7 @@ def test_create_struct_returns_remote_value_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.TensorType(np.int32) + type_signature = federated_language.TensorType(np.int32) value_1 = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -385,7 +384,7 @@ def test_create_struct_reraises_grpc_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.TensorType(np.int32) + type_signature = federated_language.TensorType(np.int32) value_1 = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -411,7 +410,7 @@ def test_create_struct_reraises_type_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.TensorType(np.int32) + type_signature = federated_language.TensorType(np.int32) value_1 = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -437,7 +436,7 @@ def test_create_selection_returns_remote_value_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.StructType([np.int32, np.int32]) + type_signature = federated_language.StructType([np.int32, np.int32]) source = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -462,7 +461,7 @@ def test_create_selection_reraises_non_retryable_grpc_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.StructType([np.int32, np.int32]) + type_signature = federated_language.StructType([np.int32, np.int32]) source = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) @@ -485,7 +484,7 @@ def test_create_selection_reraises_type_error_with_stream_structs( mock_stub, stream_structs=stream_structs ) _set_cardinalities_with_mock(executor, mock_stub) - type_signature = computation_types.StructType([np.int32, np.int32]) + type_signature = federated_language.StructType([np.int32, np.int32]) source = remote_executor.RemoteValue( executor_pb2.ValueRef(), type_signature, executor ) diff --git a/tensorflow_federated/python/core/impl/executors/value_serialization.py b/tensorflow_federated/python/core/impl/executors/value_serialization.py index c3a1853328..0203d3151d 100644 --- a/tensorflow_federated/python/core/impl/executors/value_serialization.py +++ b/tensorflow_federated/python/core/impl/executors/value_serialization.py @@ -17,46 +17,37 @@ import typing from typing import Optional +import federated_language +from federated_language.proto import array_pb2 +from federated_language.proto import computation_pb2 import numpy as np import tree -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.common_libs import tracing -from tensorflow_federated.python.core.impl.compiler import array -from tensorflow_federated.python.core.impl.computation import computation_impl from tensorflow_federated.python.core.impl.executors import executor_utils -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import dtype_utils -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import type_serialization -_SerializeReturnType = tuple[executor_pb2.Value, computation_types.Type] -_DeserializeReturnType = tuple[object, computation_types.Type] +_SerializeReturnType = tuple[executor_pb2.Value, federated_language.Type] +_DeserializeReturnType = tuple[object, federated_language.Type] -@tracing.trace +@federated_language.framework.trace def _serialize_computation( comp: computation_pb2.Computation, - type_spec: Optional[computation_types.Type], + type_spec: Optional[federated_language.Type], ) -> _SerializeReturnType: """Serializes a TFF computation.""" type_spec = executor_utils.reconcile_value_type_with_type_spec( - type_serialization.deserialize_type(comp.type), type_spec + federated_language.framework.deserialize_type(comp.type), type_spec ) return executor_pb2.Value(computation=comp), type_spec -@tracing.trace +@federated_language.framework.trace def _serialize_tensor_value( - value: object, type_spec: computation_types.TensorType -) -> tuple[executor_pb2.Value, computation_types.TensorType]: + value: object, type_spec: federated_language.TensorType +) -> tuple[executor_pb2.Value, federated_language.TensorType]: """Serializes a tensor value into `executor_pb2.Value`. Args: @@ -76,13 +67,13 @@ def _serialize_tensor_value( """ # It is necessary to coerce Python `list` and `tuple` to a numpy value, - # because these types are not an `array.Array`, but can be serialized as a + # because these types are not an `federated_language.Array`, but can be serialized as a # single `tff.TensorType`. Additionally, it is safe to coerce these kinds of # values to a numpy value of type `type_spec.dtype.type` if each element in # the sequence is compatible with `type_spec.dtype.type`. if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): if not all( - array.is_compatible_dtype(x, type_spec.dtype.type) + federated_language.array_is_compatible_dtype(x, type_spec.dtype.type) for x in tree.flatten(value) ): raise TypeError( @@ -90,14 +81,14 @@ def _serialize_tensor_value( f' {type_spec.dtype.type}.' ) value = np.asarray(value, type_spec.dtype.type) - elif not isinstance(value, typing.get_args(array.Array)): + elif not isinstance(value, typing.get_args(federated_language.Array)): raise NotImplementedError(f'Unexpected `value` found: {type(value)}.') else: # This is required because in Python 3.9 `isinstance` cannot accept a # `Union` of types and `pytype` does not parse `typing.get_args`. - value = typing.cast(array.Array, value) + value = typing.cast(federated_language.Array, value) - if not array.is_compatible_shape(value, type_spec.shape): + if not federated_language.array_is_compatible_shape(value, type_spec.shape): if isinstance(value, (np.ndarray, np.generic)): shape = value.shape else: @@ -107,7 +98,9 @@ def _serialize_tensor_value( f' {type_spec.shape}.' ) - if not array.is_compatible_dtype(value, type_spec.dtype.type): + if not federated_language.array_is_compatible_dtype( + value, type_spec.dtype.type + ): if isinstance(value, (np.ndarray, np.generic)): dtype = value.dtype.type else: @@ -120,36 +113,40 @@ def _serialize_tensor_value( # Repeated fields are used for strings and constants to maintain compatibility # with other external environments. if ( - array_shape.is_shape_scalar(type_spec.shape) + federated_language.array_shape_is_scalar(type_spec.shape) or type_spec.dtype.type is np.str_ ): - array_pb = array.to_proto(value, dtype_hint=type_spec.dtype.type) + array_pb = federated_language.array_to_proto( + value, dtype_hint=type_spec.dtype.type + ) else: - array_pb = array.to_proto_content(value, dtype_hint=type_spec.dtype.type) + array_pb = federated_language.array_to_proto_content( + value, dtype_hint=type_spec.dtype.type + ) value_pb = executor_pb2.Value(array=array_pb) return value_pb, type_spec -@tracing.trace +@federated_language.framework.trace def _serialize_array( - value: array.Array, - type_spec: computation_types.TensorType, + value: federated_language.Array, + type_spec: federated_language.TensorType, ) -> array_pb2.Array: value_proto, _ = _serialize_tensor_value(value, type_spec) return value_proto.array -@tracing.trace +@federated_language.framework.trace def _serialize_sequence_value( value: Sequence[object], - type_spec: computation_types.SequenceType, + type_spec: federated_language.SequenceType, ) -> _SerializeReturnType: """Serializes a sequence into `executor_pb2.Value`. Args: value: A list of values convertible to (potentially structures of) tensors. - type_spec: A `computation_types.Type` specifying the TFF sequence type of + type_spec: A `federated_language.Type` specifying the TFF sequence type of `value.` Returns: @@ -158,7 +155,7 @@ def _serialize_sequence_value( and `type_spec` is the type of the serialized value. """ element_type = type_spec.element - if not type_analysis.is_structure_of_tensors(element_type): + if not federated_language.framework.is_structure_of_tensors(element_type): raise ValueError( 'Expected `element_type` to contain only `tff.StructType` or' f' `tff.TensorType`, found {element_type}.' @@ -166,7 +163,7 @@ def _serialize_sequence_value( def _flatten(value, type_spec): """Flatten `value` according to `type_spec`.""" - if isinstance(type_spec, computation_types.StructType): + if isinstance(type_spec, federated_language.StructType): if isinstance(value, Mapping): value = value.values() @@ -194,7 +191,7 @@ def _flatten(value, type_spec): ) elements_proto.append(element_proto) - element_type_proto = type_serialization.serialize_type(element_type) + element_type_proto = federated_language.framework.serialize_type(element_type) sequence_proto = executor_pb2.Value.Sequence( element_type=element_type_proto, element=elements_proto ) @@ -202,11 +199,11 @@ def _flatten(value, type_spec): return value_proto, type_spec -@tracing.trace +@federated_language.framework.trace def _serialize_struct_type( struct_typed_value: object, - type_spec: computation_types.StructType, -) -> tuple[executor_pb2.Value, computation_types.StructType]: + type_spec: federated_language.StructType, +) -> tuple[executor_pb2.Value, federated_language.StructType]: """Serializes a value of tuple type.""" value_structure = structure.from_container(struct_typed_value) if len(value_structure) != len(type_spec): @@ -232,10 +229,10 @@ def _serialize_struct_type( return value_proto, type_spec -@tracing.trace +@federated_language.framework.trace def _serialize_federated_value( - federated_value: object, type_spec: computation_types.FederatedType -) -> tuple[executor_pb2.Value, computation_types.FederatedType]: + federated_value: object, type_spec: federated_language.FederatedType +) -> tuple[executor_pb2.Value, federated_language.FederatedType]: """Serializes a value of federated type.""" if type_spec.all_equal: value = [federated_value] @@ -248,15 +245,15 @@ def _serialize_federated_value( type_spec.member.check_assignable_from(it_type) value_proto.federated.value.append(federated_value_proto) value_proto.federated.type.CopyFrom( - type_serialization.serialize_type(type_spec).federated + federated_language.framework.serialize_type(type_spec).federated ) return value_proto, type_spec -@tracing.trace +@federated_language.framework.trace def serialize_value( value: object, - type_spec: Optional[computation_types.Type] = None, + type_spec: Optional[federated_language.Type] = None, ) -> _SerializeReturnType: """Serializes a value into `executor_pb2.Value`. @@ -277,9 +274,9 @@ def serialize_value( """ if isinstance(value, computation_pb2.Computation): return _serialize_computation(value, type_spec) - elif isinstance(value, computation_impl.ConcreteComputation): + elif isinstance(value, federated_language.framework.ConcreteComputation): return _serialize_computation( - computation_impl.ConcreteComputation.get_proto(value), + federated_language.framework.ConcreteComputation.get_proto(value), executor_utils.reconcile_value_with_type_spec(value, type_spec), ) elif type_spec is None: @@ -288,13 +285,13 @@ def serialize_value( 'is not a TFF computation. Asked to serialized value {v} ' ' of type {t} with None type spec.'.format(v=value, t=type(value)) ) - elif isinstance(type_spec, computation_types.TensorType): + elif isinstance(type_spec, federated_language.TensorType): return _serialize_tensor_value(value, type_spec) - elif isinstance(type_spec, computation_types.SequenceType): + elif isinstance(type_spec, federated_language.SequenceType): return _serialize_sequence_value(value, type_spec) - elif isinstance(type_spec, computation_types.StructType): + elif isinstance(type_spec, federated_language.StructType): return _serialize_struct_type(value, type_spec) - elif isinstance(type_spec, computation_types.FederatedType): + elif isinstance(type_spec, federated_language.FederatedType): return _serialize_federated_value(value, type_spec) else: raise ValueError( @@ -305,24 +302,28 @@ def serialize_value( ) -@tracing.trace +@federated_language.framework.trace def _deserialize_computation( value_proto: executor_pb2.Value, ) -> _DeserializeReturnType: """Deserializes a TFF computation.""" which_value = value_proto.computation.WhichOneof('computation') if which_value == 'literal': - value = array.from_proto(value_proto.computation.literal.value) + value = federated_language.array_from_proto( + value_proto.computation.literal.value + ) else: value = value_proto.computation - type_spec = type_serialization.deserialize_type(value_proto.computation.type) + type_spec = federated_language.framework.deserialize_type( + value_proto.computation.type + ) return value, type_spec -@tracing.trace +@federated_language.framework.trace def _deserialize_tensor_value( array_proto: array_pb2.Array, - type_hint: Optional[computation_types.TensorType] = None, + type_hint: Optional[federated_language.TensorType] = None, ) -> _DeserializeReturnType: """Deserializes a tensor value from `.Value`. @@ -338,33 +339,33 @@ def _deserialize_tensor_value( if type_hint is not None: type_spec = type_hint else: - dtype = dtype_utils.from_proto(array_proto.dtype) - shape = array_shape.from_proto(array_proto.shape) - type_spec = computation_types.TensorType(dtype, shape) + dtype = federated_language.dtype_from_proto(array_proto.dtype) + shape = federated_language.array_shape_from_proto(array_proto.shape) + type_spec = federated_language.TensorType(dtype, shape) # Repeated fields are used for strings and constants to maintain compatibility # with other external environments. if ( - array_shape.is_shape_scalar(type_spec.shape) + federated_language.array_shape_is_scalar(type_spec.shape) or type_spec.dtype.type is np.str_ ): - value = array.from_proto(array_proto) + value = federated_language.array_from_proto(array_proto) else: - value = array.from_proto_content(array_proto) + value = federated_language.array_from_proto_content(array_proto) return value, type_spec -@tracing.trace +@federated_language.framework.trace def _deserialize_sequence_value( sequence_proto: executor_pb2.Value.Sequence, - type_hint: Optional[computation_types.SequenceType] = None, + type_hint: Optional[federated_language.SequenceType] = None, ) -> _DeserializeReturnType: """Deserializes a value of sequence type. Args: sequence_proto: `Sequence` protocol buffer message. - type_hint: A `computation_types.Type` that hints at what the value type + type_hint: A `federated_language.Type` that hints at what the value type should be for executors that only return values. If the `sequence_value_proto.element_type` field was not set, the `type_hint` is used instead. @@ -375,7 +376,7 @@ def _deserialize_sequence_value( if type_hint is not None: element_type = type_hint.element else: - element_type = type_serialization.deserialize_type( + element_type = federated_language.framework.deserialize_type( sequence_proto.element_type ) @@ -390,16 +391,18 @@ def _deserialize_sequence_value( value, _ = _deserialize_tensor_value(array_proto, type_spec) flat_element.append(value) - if isinstance(element_type, computation_types.TensorType): + if isinstance(element_type, federated_language.TensorType): if len(flat_element) != 1: raise ValueError( f'Expected `flat_element` of type {element_type} to have only one' f' element, found {len(flat_element)}.' ) element, *_ = flat_element - elif isinstance(element_type, computation_types.StructType): + elif isinstance(element_type, federated_language.StructType): element = structure.pack_sequence_as(element_type, flat_element) - element = type_conversions.type_to_py_container(element, element_type) + element = federated_language.framework.type_to_py_container( + element, element_type + ) else: raise ValueError( 'Expected `element_type` to be either a `tff.StructType` or a' @@ -407,14 +410,14 @@ def _deserialize_sequence_value( ) elements.append(element) - type_spec = computation_types.SequenceType(element_type) + type_spec = federated_language.SequenceType(element_type) return elements, type_spec -@tracing.trace +@federated_language.framework.trace def _deserialize_struct_value( value_proto: executor_pb2.Value, - type_hint: Optional[computation_types.Type] = None, + type_hint: Optional[federated_language.Type] = None, ) -> _DeserializeReturnType: """Deserializes a value of struct type.""" val_elems = [] @@ -428,20 +431,23 @@ def _deserialize_struct_value( e_val, e_type = deserialize_value(e.value, e_type) val_elems.append((name, e_val)) type_elems.append((name, e_type) if name else e_type) - return (structure.Struct(val_elems), computation_types.StructType(type_elems)) + return ( + structure.Struct(val_elems), + federated_language.StructType(type_elems), + ) def _ensure_deserialized_types_compatible( - previous_type: Optional[computation_types.Type], - next_type: computation_types.Type, -) -> computation_types.Type: + previous_type: Optional[federated_language.Type], + next_type: federated_language.Type, +) -> federated_language.Type: """Ensures one of `previous_type` or `next_type` is assignable to the other. Returns the type which is assignable from the other. Args: - previous_type: Instance of `computation_types.Type` or `None`. - next_type: Instance of `computation_types.Type`. + previous_type: Instance of `federated_language.Type` or `None`. + next_type: Instance of `federated_language.Type`. Returns: The supertype of `previous_type` and `next_type`. @@ -463,10 +469,10 @@ def _ensure_deserialized_types_compatible( ) -@tracing.trace +@federated_language.framework.trace def _deserialize_federated_value( value_proto: executor_pb2.Value, - type_hint: Optional[computation_types.Type] = None, + type_hint: Optional[federated_language.Type] = None, ) -> _DeserializeReturnType: """Deserializes a value of federated type.""" if not value_proto.federated.value: @@ -497,9 +503,11 @@ def _deserialize_federated_value( item_value, next_item_type = deserialize_value(item, item_type_hint) item_type = _ensure_deserialized_types_compatible(item_type, next_item_type) value.append(item_value) - type_spec = computation_types.FederatedType( + type_spec = federated_language.FederatedType( item_type, - placement=placements.uri_to_placement_literal(placement_uri), + placement=federated_language.framework.uri_to_placement_literal( + placement_uri + ), all_equal=all_equal, ) if all_equal: @@ -507,10 +515,10 @@ def _deserialize_federated_value( return value, type_spec -@tracing.trace +@federated_language.framework.trace def deserialize_value( value_proto: executor_pb2.Value, - type_hint: Optional[computation_types.Type] = None, + type_hint: Optional[federated_language.Type] = None, ) -> _DeserializeReturnType: """Deserializes a value (of any type) from `executor_pb2.Value`. @@ -537,7 +545,7 @@ def deserialize_value( which_value = value_proto.WhichOneof('value') if which_value == 'array': if type_hint is not None and not isinstance( - type_hint, computation_types.TensorType + type_hint, federated_language.TensorType ): raise ValueError(f'Expected a `tff.TensorType`, found {type_hint}.') return _deserialize_tensor_value(value_proto.array, type_hint) @@ -556,7 +564,7 @@ def deserialize_value( def serialize_cardinalities( - cardinalities: Mapping[placements.PlacementLiteral, int], + cardinalities: Mapping[federated_language.framework.PlacementLiteral, int], ) -> list[executor_pb2.Cardinality]: serialized_cardinalities = [] for placement, cardinality in cardinalities.items(): @@ -570,10 +578,10 @@ def serialize_cardinalities( def deserialize_cardinalities( serialized_cardinalities: Collection[executor_pb2.Cardinality], -) -> dict[placements.PlacementLiteral, int]: +) -> dict[federated_language.framework.PlacementLiteral, int]: cardinalities = {} for cardinality_spec in serialized_cardinalities: - literal = placements.uri_to_placement_literal( + literal = federated_language.framework.uri_to_placement_literal( cardinality_spec.placement.uri ) cardinalities[literal] = cardinality_spec.cardinality diff --git a/tensorflow_federated/python/core/impl/executors/value_serialization_test.py b/tensorflow_federated/python/core/impl/executors/value_serialization_test.py index c12c593113..160a08621c 100644 --- a/tensorflow_federated/python/core/impl/executors/value_serialization_test.py +++ b/tensorflow_federated/python/core/impl/executors/value_serialization_test.py @@ -17,20 +17,16 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language +from federated_language.proto import computation_pb2 import numpy as np -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.proto.v0 import executor_pb2 from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.impl.executors import value_serialization -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_serialization -from tensorflow_federated.python.core.impl.types import type_test_utils # Convenience aliases. -TensorType = computation_types.TensorType +TensorType = federated_language.TensorType class _TestNamedTuple(NamedTuple): @@ -39,7 +35,7 @@ class _TestNamedTuple(NamedTuple): c: int -@federated_computation.federated_computation(np.int32) +@federated_language.federated_computation(np.int32) def _identity(x): return x @@ -66,9 +62,13 @@ def test_serialize_deserialize_tensor_value_without_hint( value_proto, value_type = value_serialization.serialize_value( x, serialize_type_spec ) - type_test_utils.assert_types_identical(value_type, serialize_type_spec) + federated_language.framework.assert_types_identical( + value_type, serialize_type_spec + ) y, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_types_identical(type_spec, serialize_type_spec) + federated_language.framework.assert_types_identical( + type_spec, serialize_type_spec + ) self.assertEqual(y.dtype, serialize_type_spec.dtype) if isinstance(y, (np.ndarray, np.generic)): np.testing.assert_array_equal(y, x) @@ -81,9 +81,13 @@ def test_serialize_deserialize_tensor_value_unknown_shape_without_hint(self): value_proto, value_type = value_serialization.serialize_value( x, serialize_type_spec ) - type_test_utils.assert_type_assignable_from(value_type, serialize_type_spec) + federated_language.framework.assert_type_assignable_from( + value_type, serialize_type_spec + ) y, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_type_assignable_from(serialize_type_spec, type_spec) + federated_language.framework.assert_type_assignable_from( + serialize_type_spec, type_spec + ) self.assertEqual(y.dtype, serialize_type_spec.dtype) if isinstance(y, (np.ndarray, np.generic)): np.testing.assert_array_equal(y, x, strict=True) @@ -97,11 +101,13 @@ def test_serialize_deserialize_tensor_value_with_hint( value_proto, value_type = value_serialization.serialize_value( x, serialize_type_spec ) - type_test_utils.assert_types_identical(value_type, serialize_type_spec) + federated_language.framework.assert_types_identical( + value_type, serialize_type_spec + ) y, deserialize_type_spec = value_serialization.deserialize_value( value_proto, type_hint=serialize_type_spec ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( deserialize_type_spec, serialize_type_spec ) self.assertEqual(y.dtype, serialize_type_spec.dtype) @@ -152,11 +158,11 @@ def test_serialize_deserialize_string_value( value_proto, value_type = value_serialization.serialize_value( value, type_spec ) - type_test_utils.assert_types_identical(value_type, type_spec) + federated_language.framework.assert_types_identical(value_type, type_spec) result, result_type = value_serialization.deserialize_value( value_proto, type_spec ) - type_test_utils.assert_types_identical(result_type, type_spec) + federated_language.framework.assert_types_identical(result_type, type_spec) if isinstance(result, (np.ndarray, np.generic)): np.testing.assert_array_equal(result, expected_value, strict=True) @@ -173,11 +179,13 @@ def test_serialize_deserialize_tensor_value_with_nontrivial_shape(self): value_proto, value_type = value_serialization.serialize_value( x, TensorType(np.int32, [3]) ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( value_type, TensorType(np.int32, [3]) ) y, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_types_identical(type_spec, TensorType(np.int32, [3])) + federated_language.framework.assert_types_identical( + type_spec, TensorType(np.int32, [3]) + ) if isinstance(y, (np.ndarray, np.generic)): np.testing.assert_array_equal(y, x, strict=True) else: @@ -193,31 +201,31 @@ def test_serialize_struct_with_type_element_mismatch(self): ), ): value_serialization.serialize_value( - x, computation_types.StructType([('a', np.int32), ('b', np.int32)]) + x, federated_language.StructType([('a', np.int32), ('b', np.int32)]) ) def test_serialize_sequence_raises_type_error_with_invalid_type_spec(self): value = [1, 2, 3] - type_spec = computation_types.SequenceType(np.float32) + type_spec = federated_language.SequenceType(np.float32) with self.assertRaisesRegex(TypeError, 'Failed to serialize the value'): value_serialization.serialize_value(value, type_spec) @parameterized.named_parameters( - ('scalar', [1, 2, 3], computation_types.SequenceType(np.int32)), + ('scalar', [1, 2, 3], federated_language.SequenceType(np.int32)), ( 'tuple', [(1, 2, 3), (4, 5, 6), (7, 8, 9)], - computation_types.SequenceType([np.int32, np.int32, np.int32]), + federated_language.SequenceType([np.int32, np.int32, np.int32]), ), ( 'tuple_empty', [(), (), ()], - computation_types.SequenceType([]), + federated_language.SequenceType([]), ), ( 'tuple_singleton', [(1,), (2,), (3,)], - computation_types.SequenceType([np.int32]), + federated_language.SequenceType([np.int32]), ), ( 'dict', @@ -226,7 +234,7 @@ def test_serialize_sequence_raises_type_error_with_invalid_type_spec(self): {'a': 4, 'b': 5, 'c': 6}, {'a': 7, 'b': 8, 'c': 9}, ], - computation_types.SequenceType([ + federated_language.SequenceType([ ('a', np.int32), ('b', np.int32), ('c', np.int32), @@ -239,8 +247,8 @@ def test_serialize_sequence_raises_type_error_with_invalid_type_spec(self): _TestNamedTuple(4, 5, 6), _TestNamedTuple(7, 8, 9), ], - computation_types.SequenceType( - computation_types.StructWithPythonType( + federated_language.SequenceType( + federated_language.StructWithPythonType( [ ('a', np.int32), ('b', np.int32), @@ -255,16 +263,16 @@ def test_serialize_deserialize_sequence(self, value, type_spec): value_proto, value_type = value_serialization.serialize_value( value, type_spec ) - type_test_utils.assert_types_identical(value_type, type_spec) + federated_language.framework.assert_types_identical(value_type, type_spec) result, result_type = value_serialization.deserialize_value( value_proto, type_spec ) - type_test_utils.assert_types_equivalent(result_type, type_spec) + federated_language.framework.assert_types_equivalent(result_type, type_spec) self.assertEqual(result, value) def test_serialize_deserialize_tensor_value_with_bad_shape(self): value = np.array([10, 20, 30], np.int32) - type_spec = computation_types.TensorType(np.int32) + type_spec = federated_language.TensorType(np.int32) with self.assertRaises(TypeError): value_serialization.serialize_value(value, type_spec) @@ -272,21 +280,21 @@ def test_serialize_deserialize_tensor_value_with_bad_shape(self): def test_serialize_deserialize_computation_value(self): value_proto, value_type = value_serialization.serialize_value(_identity) self.assertEqual(value_proto.WhichOneof('value'), 'computation') - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( value_type, - computation_types.FunctionType(parameter=np.int32, result=np.int32), + federated_language.FunctionType(parameter=np.int32, result=np.int32), ) _, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( type_spec, - computation_types.FunctionType(parameter=np.int32, result=np.int32), + federated_language.FunctionType(parameter=np.int32, result=np.int32), ) def test_serialize_deserialize_nested_tuple_value_with_names(self): x = collections.OrderedDict( a=10, b=[20, 30], c=collections.OrderedDict(d=40) ) - x_type = computation_types.StructType( + x_type = federated_language.StructType( collections.OrderedDict( a=np.int32, b=[np.int32, np.int32], @@ -294,41 +302,48 @@ def test_serialize_deserialize_nested_tuple_value_with_names(self): ) ) value_proto, value_type = value_serialization.serialize_value(x, x_type) - type_test_utils.assert_types_identical(value_type, x_type) + federated_language.framework.assert_types_identical(value_type, x_type) y, type_spec = value_serialization.deserialize_value(value_proto) # Don't assert on the Python container since it is lost in serialization. - type_test_utils.assert_types_equivalent(type_spec, x_type) + federated_language.framework.assert_types_equivalent(type_spec, x_type) self.assertEqual(y, structure.from_container(x, recursive=True)) def test_serialize_deserialize_nested_tuple_value_without_names(self): x = (10, 20) - x_type = computation_types.StructType([np.int32, np.int32]) + x_type = federated_language.StructType([np.int32, np.int32]) value_proto, value_type = value_serialization.serialize_value(x, x_type) - type_test_utils.assert_types_identical(value_type, x_type) + federated_language.framework.assert_types_identical(value_type, x_type) y, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_types_equivalent(type_spec, x_type) + federated_language.framework.assert_types_equivalent(type_spec, x_type) self.assertEqual(y, structure.from_container((10, 20))) def test_serialize_deserialize_federated_at_clients(self): x = [10, 20] - x_type = computation_types.FederatedType(np.int32, placements.CLIENTS) + x_type = federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) value_proto, value_type = value_serialization.serialize_value(x, x_type) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( value_type, - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) y, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_types_identical( - type_spec, computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.framework.assert_types_identical( + type_spec, + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) self.assertEqual(y, [10, 20]) def test_deserialize_federated_value_with_unset_member_type(self): x = 10 - x_type = computation_types.TensorType(np.int32) + x_type = federated_language.TensorType(np.int32) member_proto, _ = value_serialization.serialize_value(x, x_type) - fully_specified_type_at_clients = type_serialization.serialize_type( - computation_types.FederatedType(np.int32, placements.CLIENTS) + fully_specified_type_at_clients = ( + federated_language.framework.serialize_type( + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + ) ) unspecified_member_federated_type = computation_pb2.FederatedType( @@ -344,9 +359,9 @@ def test_deserialize_federated_value_with_unset_member_type(self): deserialized_federated_value, deserialized_type_spec = ( value_serialization.deserialize_value(federated_value_proto) ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( deserialized_type_spec, - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) self.assertEqual(deserialized_federated_value, [10]) @@ -354,13 +369,17 @@ def test_deserialize_federated_value_with_incompatible_member_types_raises( self, ): x = 10 - x_type = computation_types.TensorType(np.int32) + x_type = federated_language.TensorType(np.int32) int_member_proto, _ = value_serialization.serialize_value(x, x_type) y = 10.0 - y_type = computation_types.TensorType(np.float32) + y_type = federated_language.TensorType(np.float32) float_member_proto, _ = value_serialization.serialize_value(y, y_type) - fully_specified_type_at_clients = type_serialization.serialize_type( - computation_types.FederatedType(np.int32, placements.CLIENTS) + fully_specified_type_at_clients = ( + federated_language.framework.serialize_type( + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) + ) ) unspecified_member_federated_type = computation_pb2.FederatedType( @@ -387,36 +406,38 @@ def test_deserialize_federated_all_equal_value_takes_first_element(self): value=[tensor_value_pb] * num_clients, type=computation_pb2.FederatedType( placement=computation_pb2.PlacementSpec( - value=computation_pb2.Placement(uri=placements.CLIENTS.uri) + value=computation_pb2.Placement( + uri=federated_language.CLIENTS.uri + ) ) ), ) ) - all_equal_clients_type_hint = computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True + all_equal_clients_type_hint = federated_language.FederatedType( + np.int32, federated_language.CLIENTS, all_equal=True ) deserialized_value, deserialized_type = ( value_serialization.deserialize_value( value_pb, all_equal_clients_type_hint ) ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( deserialized_type, all_equal_clients_type_hint ) self.assertEqual(deserialized_value, 10) def test_deserialize_federated_value_promotes_types(self): x = [10] - smaller_type = computation_types.StructType([(None, np.int32)]) + smaller_type = federated_language.StructType([(None, np.int32)]) smaller_type_member_proto, _ = value_serialization.serialize_value( x, smaller_type ) - larger_type = computation_types.StructType([('a', np.int32)]) + larger_type = federated_language.StructType([('a', np.int32)]) larger_type_member_proto, _ = value_serialization.serialize_value( x, larger_type ) - type_at_clients = type_serialization.serialize_type( - computation_types.FederatedType(np.int32, placements.CLIENTS) + type_at_clients = federated_language.framework.serialize_type( + federated_language.FederatedType(np.int32, federated_language.CLIENTS) ) unspecified_member_federated_type = computation_pb2.FederatedType( @@ -432,20 +453,25 @@ def test_deserialize_federated_value_promotes_types(self): _, deserialized_type_spec = value_serialization.deserialize_value( federated_value_proto ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( deserialized_type_spec, - computation_types.FederatedType(larger_type, placements.CLIENTS), + federated_language.FederatedType( + larger_type, federated_language.CLIENTS + ), ) def test_serialize_deserialize_federated_at_server(self): x = 10 - x_type = computation_types.FederatedType(np.int32, placements.SERVER) + x_type = federated_language.FederatedType( + np.int32, federated_language.SERVER + ) value_proto, value_type = value_serialization.serialize_value(x, x_type) - type_test_utils.assert_types_identical( - value_type, computation_types.FederatedType(np.int32, placements.SERVER) + federated_language.framework.assert_types_identical( + value_type, + federated_language.FederatedType(np.int32, federated_language.SERVER), ) y, type_spec = value_serialization.deserialize_value(value_proto) - type_test_utils.assert_types_identical(type_spec, x_type) + federated_language.framework.assert_types_identical(type_spec, x_type) self.assertEqual(y, 10) @@ -455,8 +481,8 @@ def test_serialize_deserialize_clients_and_server_cardinalities_roundtrip( self, ): client_and_server_cardinalities = { - placements.CLIENTS: 10, - placements.SERVER: 1, + federated_language.CLIENTS: 10, + federated_language.SERVER: 1, } cardinalities_list = value_serialization.serialize_cardinalities( client_and_server_cardinalities @@ -471,7 +497,7 @@ def test_serialize_deserialize_clients_and_server_cardinalities_roundtrip( ) def test_serialize_deserialize_clients_alone(self): - client_cardinalities = {placements.CLIENTS: 10} + client_cardinalities = {federated_language.CLIENTS: 10} cardinalities_list = value_serialization.serialize_cardinalities( client_cardinalities ) diff --git a/tensorflow_federated/python/core/impl/federated_context/BUILD b/tensorflow_federated/python/core/impl/federated_context/BUILD deleted file mode 100644 index 8f08fee1db..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/BUILD +++ /dev/null @@ -1,216 +0,0 @@ -load("@rules_python//python:defs.bzl", "py_library", "py_test") - -package( - default_applicable_licenses = ["//:package_license"], - default_visibility = [ - ":federated_context_packages", - "//tensorflow_federated/python/core/impl:impl_users", - "//tensorflow_federated/python/core/impl/executors:executors_packages", - ], -) - -package_group( - name = "federated_context_packages", - packages = ["//tensorflow_federated/python/core/impl/federated_context/..."], -) - -licenses(["notice"]) - -py_library( - name = "federated_context", - srcs = ["__init__.py"], - visibility = ["//tools/python_package:python_package_tool"], -) - -py_library(name = "data") - -py_library( - name = "federated_computation", - srcs = ["federated_computation.py"], - deps = [ - ":federated_computation_utils", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - ], -) - -py_test( - name = "federated_computation_test", - size = "small", - srcs = ["federated_computation_test.py"], - deps = [ - ":federated_computation", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/context_stack:get_context_stack", - "//tensorflow_federated/python/core/impl/context_stack:runtime_error_context", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_library( - name = "federated_computation_context", - srcs = ["federated_computation_context.py"], - deps = [ - ":value_impl", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/context_stack:symbol_binding_context", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", - ], -) - -py_test( - name = "federated_computation_context_test", - size = "small", - srcs = ["federated_computation_context_test.py"], - deps = [ - ":federated_computation", - ":federated_computation_context", - ":value_impl", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:computation_factory", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "federated_computation_utils", - srcs = ["federated_computation_utils.py"], - deps = [ - ":federated_computation_context", - ":value_impl", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - ], -) - -py_test( - name = "federated_computation_utils_test", - size = "small", - srcs = ["federated_computation_utils_test.py"], - deps = [ - ":federated_computation_utils", - "//tensorflow_federated/python/core/impl/computation:computation_wrapper", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_library( - name = "intrinsics", - srcs = ["intrinsics.py"], - deps = [ - ":value_impl", - ":value_utils", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/context_stack:symbol_binding_context", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - "//tensorflow_federated/python/core/impl/types:type_factory", - ], -) - -py_test( - name = "intrinsics_test", - srcs = ["intrinsics_test.py"], - deps = [ - ":federated_computation_context", - ":intrinsics", - ":value_impl", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_test_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "value_impl", - srcs = ["value_impl.py"], - deps = [ - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:array", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/computation:function_utils", - "//tensorflow_federated/python/core/impl/computation:polymorphic_computation", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/context_stack:symbol_binding_context", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_conversions", - "//tensorflow_federated/python/core/impl/types:typed_object", - ], -) - -py_test( - name = "value_impl_test", - size = "small", - srcs = ["value_impl_test.py"], - deps = [ - ":federated_computation_context", - ":value_impl", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:computation_factory", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "value_utils", - srcs = ["value_utils.py"], - deps = [ - ":value_impl", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_block_factory", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_test( - name = "value_utils_test", - size = "small", - srcs = ["value_utils_test.py"], - deps = [ - ":federated_computation", - ":federated_computation_context", - ":value_impl", - ":value_utils", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:computation_factory", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) diff --git a/tensorflow_federated/python/core/impl/federated_context/__init__.py b/tensorflow_federated/python/core/impl/federated_context/__init__.py deleted file mode 100644 index aab63a6363..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Libraries for interacting with a federated context.""" diff --git a/tensorflow_federated/python/core/impl/federated_context/federated_computation.py b/tensorflow_federated/python/core/impl/federated_context/federated_computation.py deleted file mode 100644 index 16f7e884ab..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/federated_computation.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Definition of a federated computation.""" - -from typing import Optional - -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation_utils - - -def _federated_computation_wrapper_fn( - fn, - parameter_type, - unpack: Optional[bool], - name: Optional[str] = None, - **kwargs -): - """Wrapper function to plug orchestration logic into the TFF framework.""" - del kwargs # Unused. - if parameter_type is None: - parameter_name = None - else: - parameter_name = 'arg' - fn = function_utils.wrap_as_zero_or_one_arg_callable( - fn, parameter_type, unpack - ) - context_stack = context_stack_impl.context_stack - target_lambda, extra_type_spec = ( - federated_computation_utils.zero_or_one_arg_fn_to_building_block( - fn, - parameter_name, - parameter_type, - context_stack, - suggested_name=name, - ) - ) - return computation_impl.ConcreteComputation( - computation_proto=target_lambda.proto, - context_stack=context_stack, - annotated_type=extra_type_spec, - ) - - -federated_computation = computation_wrapper.ComputationWrapper( - _federated_computation_wrapper_fn -) -federated_computation.__doc__ = """Decorates/wraps Python functions as TFF federated/composite computations. - - The term *federated computation* as used here refers to any computation that - uses TFF programming abstractions. Examples of such computations may include - federated training or federated evaluation that involve both client-side and - server-side logic and involve network communication. However, this - decorator/wrapper can also be used to construct composite computations that - only involve local processing on a client or on a server. - - The main feature that distinguishes *federated computation* function bodies - in Python from the bodies of TensorFlow defuns is that whereas in the latter, - one slices and dices `tf.Tensor` instances using a variety of TensorFlow ops, - in the former one slices and dices `tff.Value` instances using TFF operators. - - The supported modes of usage are identical to those for - `tff.tensorflow.computation`. - - Example: - - ```python - @tff.federated_computation((tff.FunctionType(np.int32, np.int32), np.int32)) - def foo(f, x): - return f(f(x)) - ``` - - The above defines `foo` as a function that takes a tuple consisting of an - unary integer operator as the first element, and an integer as the second - element, and returns the result of applying the unary operator to the - integer twice. The body of `foo` does not contain federated communication - operators, but we define it with `tff.federated_computation` as it can be - used as building block in any section of TFF code (except inside sections - of pure TensorFlow logic). - - Args: - *args: Either a Python function, or TFF type spec, or both (function first), - or neither. See also `tff.tensorflow.computation` for an extended - documentation. - - Returns: - If invoked with a function as an argument, returns an instance of a TFF - computation constructed based on this function. If called without one, as - in the typical decorator style of usage, returns a callable that expects - to be called with the function definition supplied as a parameter. See - also `tff.tensorflow.computation` for an extended documentation. - """ diff --git a/tensorflow_federated/python/core/impl/federated_context/federated_computation_context.py b/tensorflow_federated/python/core/impl/federated_context/federated_computation_context.py deleted file mode 100644 index b1d09e508d..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/federated_computation_context.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""The implementation of a context to use in building federated computations.""" - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.context_stack import symbol_binding_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis - - -class FederatedComputationContext( - symbol_binding_context.SymbolBindingContext[ - building_blocks.ComputationBuildingBlock, - building_blocks.Reference, - ] -): - """The context for building federated computations. - - This context additionally holds a list of symbols which are bound to - `building_block.ComputationBuildingBlocks` during construction of - `tff.Values`, and which respect identical semantics to the binding of locals - in `building_blocks.Blocks`. - - Any `tff.Value` constructed in this context may add such a symbol binding, - and thereafter refer to the returned reference in place of the bound - computation. It is then the responsibility of the installer of this context - to ensure that the symbols bound during the `tff.Value` construction process - are appropriately packaged in the result. - """ - - def __init__(self, context_stack, suggested_name=None, parent=None): - """Creates this context. - - Args: - context_stack: The context stack to use. - suggested_name: The optional suggested name of the context, a string. It - may be modified to make it different from the names of any of the - ancestors on the context stack. - parent: The optional parent context. If not `None`, it must be an instance - of `FederatedComputationContext`. - """ - py_typecheck.check_type(context_stack, context_stack_base.ContextStack) - if suggested_name: - py_typecheck.check_type(suggested_name, str) - suggested_name = str(suggested_name) - else: - suggested_name = 'FEDERATED' - if parent is not None: - py_typecheck.check_type(parent, FederatedComputationContext) - ancestor = parent - ancestor_names = set() - while ancestor is not None: - ancestor_names.add(ancestor.name) - ancestor = ancestor.parent - name = suggested_name - name_count = 0 - while name in ancestor_names: - name_count = name_count + 1 - name = '{}_{}'.format(suggested_name, name_count) - self._context_stack = context_stack - self._parent = parent - self._name = name - self._symbol_bindings = [] - self._next_symbol_val = 0 - - @property - def name(self): - return self._name - - @property - def parent(self): - return self._parent - - def bind_computation_to_reference( - self, comp: building_blocks.ComputationBuildingBlock - ) -> building_blocks.Reference: - """Binds a computation to a symbol, returns a reference to this binding.""" - name = 'fc_{name}_symbol_{val}'.format( - name=self._name, val=self._next_symbol_val - ) - self._next_symbol_val += 1 - self._symbol_bindings.append((name, comp)) - ref = building_blocks.Reference(name, comp.type_signature) - return ref - - @property - def symbol_bindings( - self, - ) -> list[tuple[str, building_blocks.ComputationBuildingBlock]]: - return self._symbol_bindings - - def invoke(self, comp, arg): - fn = value_impl.to_value(comp, type_spec=None) - tys = fn.type_signature - py_typecheck.check_type(tys, computation_types.FunctionType) - if arg is not None: - if tys.parameter is None: # pytype: disable=attribute-error - raise ValueError( - 'A computation of type {} does not expect any arguments, but got ' - 'an argument {}.'.format(tys, arg) - ) - arg = value_impl.to_value( - arg, - type_spec=tys.parameter, # pytype: disable=attribute-error - zip_if_needed=True, - ) - type_analysis.check_type(arg, tys.parameter) # pytype: disable=attribute-error - ret_val = fn(arg) - else: - if tys.parameter is not None: # pytype: disable=attribute-error - raise ValueError( - 'A computation of type {} expects an argument of type {}, but got ' - ' no argument.'.format(tys, tys.parameter) # pytype: disable=attribute-error - ) - ret_val = fn() - type_analysis.check_type(ret_val, tys.result) # pytype: disable=attribute-error - return ret_val diff --git a/tensorflow_federated/python/core/impl/federated_context/federated_computation_context_test.py b/tensorflow_federated/python/core/impl/federated_context/federated_computation_context_test.py deleted file mode 100644 index 41041d540b..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/federated_computation_context_test.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -class FederatedComputationContextTest(absltest.TestCase): - - def test_invoke_returns_value_with_correct_type(self): - tensor_type = computation_types.TensorType(np.int32) - computation_proto = computation_factory.create_lambda_identity(tensor_type) - computation = computation_impl.ConcreteComputation( - computation_proto=computation_proto, - context_stack=context_stack_impl.context_stack, - ) - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - - with context_stack_impl.context_stack.install(context): - result = context.invoke(computation, 1) - - self.assertIsInstance(result, value_impl.Value) - self.assertEqual(result.type_signature, tensor_type) - - def test_ingest_zips_value_when_necessary_to_match_federated_type(self): - # Expects `{}@C` - @federated_computation.federated_computation( - computation_types.FederatedType( - (np.int32, np.int32), placements.CLIENTS - ) - ) - def fn(_): - return () - - # This thing will be <{int}@C, {int}@C> - arg = building_blocks.Struct([ - building_blocks.Reference( - 'x', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - building_blocks.Reference( - 'y', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]) - - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - with context_stack_impl.context_stack.install(context): - fn(arg) - - def test_ingest_zips_federated_under_struct(self): - - @federated_computation.federated_computation( - computation_types.StructType([( - None, - collections.OrderedDict( - x=computation_types.FederatedType(np.int32, placements.CLIENTS), - y=computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - )]) - ) - def fn(_): - return () - - arg = building_blocks.Struct( - [ - building_blocks.Struct([ - building_blocks.Reference( - 'x', - computation_types.FederatedType( - np.int32, placements.CLIENTS - ), - ), - building_blocks.Reference( - 'y', - computation_types.FederatedType( - np.int32, placements.CLIENTS - ), - ), - ]) - ] - ) - - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - with context_stack_impl.context_stack.install(context): - fn(arg) - - def test_construction_populates_name(self): - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - self.assertEqual(context.name, 'FEDERATED') - - def test_suggested_name_populates_name_attribute(self): - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack, suggested_name='FOO' - ) - self.assertEqual(context.name, 'FOO') - - def test_child_name_doesnt_conflict(self): - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack, suggested_name='FOO' - ) - self.assertEqual(context.name, 'FOO') - context2 = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack, suggested_name='FOO', parent=context - ) - self.assertEqual(context2.name, 'FOO_1') - context3 = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack, suggested_name='FOO', parent=context2 - ) - self.assertEqual(context3.name, 'FOO_2') - - def test_parent_populated_correctly(self): - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - context2 = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack, parent=context - ) - self.assertIs(context2.parent, context) - - def test_bind_single_computation_to_reference(self): - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - data = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - ref = context.bind_computation_to_reference(data) - symbol_bindings = context.symbol_bindings - bound_symbol_name = symbol_bindings[0][0] - - self.assertIsInstance(ref, building_blocks.Reference) - self.assertEqual(ref.type_signature, data.type_signature) - self.assertLen(symbol_bindings, 1) - self.assertEqual(bound_symbol_name, ref.name) - - def test_bind_two_computations_to_reference(self): - context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - lit = building_blocks.Literal(1, computation_types.TensorType(np.int32)) - float_data = building_blocks.Literal( - 2.0, computation_types.TensorType(np.float32) - ) - ref1 = context.bind_computation_to_reference(lit) - ref2 = context.bind_computation_to_reference(float_data) - symbol_bindings = context.symbol_bindings - - self.assertIsInstance(ref1, building_blocks.Reference) - self.assertIsInstance(ref2, building_blocks.Reference) - - self.assertEqual(ref1.type_signature, lit.type_signature) - self.assertEqual(ref2.type_signature, float_data.type_signature) - self.assertLen(symbol_bindings, 2) - self.assertEqual(symbol_bindings[0][0], ref1.name) - self.assertEqual(symbol_bindings[1][0], ref2.name) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/federated_context/federated_computation_test.py b/tensorflow_federated/python/core/impl/federated_context/federated_computation_test.py deleted file mode 100644 index 283b7471ec..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/federated_computation_test.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.context_stack import get_context_stack -from tensorflow_federated.python.core.impl.context_stack import runtime_error_context -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types - - -class FederatedComputationWrapperTest(absltest.TestCase): - - def test_federated_computation_wrapper(self): - - @federated_computation.federated_computation( - (computation_types.FunctionType(np.int32, np.int32), np.int32) - ) - def foo(f, x): - return f(f(x)) - - self.assertIsInstance(foo, computation_impl.ConcreteComputation) - self.assertEqual( - str(foo.type_signature), '( int32),x=int32> -> int32)' - ) - - self.assertEqual( - str(foo.to_building_block()), - ( - '(foo_arg -> (let' - ' fc_foo_symbol_0=foo_arg[0](foo_arg[1]),fc_foo_symbol_1=foo_arg[0](fc_foo_symbol_0)' - ' in fc_foo_symbol_1))' - ), - ) - - def test_stackframes_in_errors(self): - class DummyError(RuntimeError): - pass - - with self.assertRaises(DummyError): - @federated_computation.federated_computation - def _(): - raise DummyError() - - def test_empty_tuple_arg(self): - - @federated_computation.federated_computation( - computation_types.StructType([]) - ) - def foo(x): - return x - - self.assertIsInstance(foo, computation_impl.ConcreteComputation) - self.assertEqual(str(foo.type_signature), '(<> -> <>)') - - self.assertEqual(str(foo.to_building_block()), '(foo_arg -> foo_arg)') - - def test_stack_resets_on_none_returned(self): - stack = get_context_stack.get_context_stack() - self.assertIsInstance( - stack.current, runtime_error_context.RuntimeErrorContext - ) - - with self.assertRaises(computation_wrapper.ComputationReturnedNoneError): - @federated_computation.federated_computation() - def _(): - pass - - self.assertIsInstance( - stack.current, runtime_error_context.RuntimeErrorContext - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/federated_context/federated_computation_utils.py b/tensorflow_federated/python/core/impl/federated_context/federated_computation_utils.py deleted file mode 100644 index f6ccd39899..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/federated_computation_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# limitations under the License. -"""Helpers for creating larger structures out of computing building blocks.""" - -from typing import Optional - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.context_stack import context_stack_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_conversions - - -def zero_or_one_arg_fn_to_building_block( - fn, - parameter_name: Optional[str], - parameter_type: Optional[computation_types.Type], - context_stack: context_stack_base.ContextStack, - suggested_name: Optional[str] = None, -) -> tuple[ - building_blocks.ComputationBuildingBlock, computation_types.FunctionType -]: - """Converts a zero- or one-argument `fn` into a computation building block. - - Args: - fn: A function with 0 or 1 arguments that contains orchestration logic, - i.e., that expects zero or one `values_base.Value` and returns a result - convertible to the same. - parameter_name: The name of the parameter, or `None` if there is't any. - parameter_type: The `tff.Type` of the parameter, or `None` if there's none. - context_stack: The context stack to use. - suggested_name: The optional suggested name to use for the federated context - that will be used to serialize this function's body (ideally the name of - the underlying Python function). It might be modified to avoid conflicts. - - Returns: - A tuple of `(building_blocks.ComputationBuildingBlock, - computation_types.Type)`, where the first element contains the logic from - `fn`, and the second element contains potentially annotated type information - for the result of `fn`. - - Raises: - ValueError: if `fn` is incompatible with `parameter_type`. - """ - py_typecheck.check_type(context_stack, context_stack_base.ContextStack) - if suggested_name is not None: - py_typecheck.check_type(suggested_name, str) - if isinstance( - context_stack.current, - federated_computation_context.FederatedComputationContext, - ): - parent_context = context_stack.current - else: - parent_context = None - context = federated_computation_context.FederatedComputationContext( - context_stack, suggested_name=suggested_name, parent=parent_context - ) - if parameter_name is not None: - py_typecheck.check_type(parameter_name, str) - parameter_name = '{}_{}'.format(context.name, str(parameter_name)) - with context_stack.install(context): - if parameter_type is not None: - result = fn( - value_impl.Value( - building_blocks.Reference(parameter_name, parameter_type), - ) - ) - else: - result = fn() - if result is None: - raise computation_wrapper.ComputationReturnedNoneError(fn) - annotated_result_type = type_conversions.infer_type(result) - result = value_impl.to_value(result, type_spec=annotated_result_type) - result_comp = result.comp - symbols_bound_in_context = context_stack.current.symbol_bindings - if symbols_bound_in_context: - result_comp = building_blocks.Block( - local_symbols=symbols_bound_in_context, result=result_comp - ) - annotated_type = computation_types.FunctionType( - parameter_type, annotated_result_type - ) - return ( - building_blocks.Lambda(parameter_name, parameter_type, result_comp), - annotated_type, - ) diff --git a/tensorflow_federated/python/core/impl/federated_context/federated_computation_utils_test.py b/tensorflow_federated/python/core/impl/federated_context/federated_computation_utils_test.py deleted file mode 100644 index fb7a338130..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/federated_computation_utils_test.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.core.impl.computation import computation_wrapper -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation_utils -from tensorflow_federated.python.core.impl.types import computation_types - -TestNamedTuple = collections.namedtuple('TestTuple', ['x', 'y']) - - -class ZeroOrOneArgFnToBuildingBlockTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'nested_fn_same', - lambda f, x: f(f(x)), - computation_types.StructType([ - ('f', computation_types.FunctionType(np.int32, np.int32)), - ('x', np.int32), - ]), - ( - '(FEDERATED_foo -> (let ' - 'fc_FEDERATED_symbol_0=FEDERATED_foo.f(FEDERATED_foo.x),' - 'fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0)' - ' in fc_FEDERATED_symbol_1))' - ), - ), - ( - 'nested_fn_different', - lambda f, g, x: f(g(x)), - computation_types.StructType([ - ('f', computation_types.FunctionType(np.int32, np.int32)), - ('g', computation_types.FunctionType(np.int32, np.int32)), - ('x', np.int32), - ]), - ( - '(FEDERATED_foo -> (let ' - 'fc_FEDERATED_symbol_0=FEDERATED_foo.g(FEDERATED_foo.x),' - 'fc_FEDERATED_symbol_1=FEDERATED_foo.f(fc_FEDERATED_symbol_0)' - ' in fc_FEDERATED_symbol_1))' - ), - ), - ( - 'selection', - lambda x: (x[1], x[0]), - computation_types.StructType([np.int32, np.int32]), - '(FEDERATED_foo -> )', - ), - ('constant', lambda: 'stuff', None, "( -> b'stuff')"), - ) - def test_returns_result(self, fn, parameter_type, fn_str): - parameter_name = 'foo' if parameter_type is not None else None - fn = function_utils.wrap_as_zero_or_one_arg_callable(fn, parameter_type) - result, _ = ( - federated_computation_utils.zero_or_one_arg_fn_to_building_block( - fn, parameter_name, parameter_type, context_stack_impl.context_stack - ) - ) - self.assertEqual(str(result), fn_str) - - @parameterized.named_parameters( - ( - 'tuple', - lambda x: (x[1], x[0]), - computation_types.StructType([np.int32, np.float32]), - computation_types.StructWithPythonType( - [ - (None, np.float32), - (None, np.int32), - ], - tuple, - ), - ), - ( - 'list', - lambda x: [x[1], x[0]], - computation_types.StructType([np.int32, np.float32]), - computation_types.StructWithPythonType( - [ - (None, np.float32), - (None, np.int32), - ], - list, - ), - ), - ( - 'odict', - lambda x: collections.OrderedDict([('A', x[1]), ('B', x[0])]), - computation_types.StructType([np.int32, np.float32]), - computation_types.StructWithPythonType( - [ - ('A', np.float32), - ('B', np.int32), - ], - collections.OrderedDict, - ), - ), - ( - 'namedtuple', - lambda x: TestNamedTuple(x=x[1], y=x[0]), - computation_types.StructType([np.int32, np.float32]), - computation_types.StructWithPythonType( - [ - ('x', np.float32), - ('y', np.int32), - ], - TestNamedTuple, - ), - ), - ) - def test_returns_result_with_py_container( - self, fn, parameter_type, expected_result_type - ): - parameter_name = 'foo' - fn = function_utils.wrap_as_zero_or_one_arg_callable(fn, parameter_type) - _, type_signature = ( - federated_computation_utils.zero_or_one_arg_fn_to_building_block( - fn, parameter_name, parameter_type, context_stack_impl.context_stack - ) - ) - self.assertIs(type(type_signature.result), type(expected_result_type)) - self.assertIs( - type_signature.result.python_container, - expected_result_type.python_container, - ) - self.assertEqual(type_signature.result, expected_result_type) - - def test_raises_value_error_with_none_result(self): - fn = lambda: None - parameter_type = None - fn = function_utils.wrap_as_zero_or_one_arg_callable(fn, parameter_type) - - with self.assertRaises(computation_wrapper.ComputationReturnedNoneError): - federated_computation_utils.zero_or_one_arg_fn_to_building_block( - fn, None, parameter_type, context_stack_impl.context_stack - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/federated_context/intrinsics.py b/tensorflow_federated/python/core/impl/federated_context/intrinsics.py deleted file mode 100644 index ea7931da87..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/intrinsics.py +++ /dev/null @@ -1,1110 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A factory of intrinsics for use in composing federated computations.""" - -from typing import NoReturn -import warnings - -import numpy as np - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import intrinsic_defs -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import symbol_binding_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.federated_context import value_utils -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.core.impl.types import type_factory - - -def _bind_comp_as_reference(comp): - fc_context = context_stack_impl.context_stack.current - if not isinstance(fc_context, symbol_binding_context.SymbolBindingContext): - raise context_base.ContextError( - f'Attempted to construct an intrinsic in context {fc_context} which ' - ' does not support binding references.' - ) - return fc_context.bind_computation_to_reference(comp) - - -def federated_aggregate( - value, zero, accumulate, merge, report -) -> value_impl.Value: - """Aggregates `value` from `tff.CLIENTS` to `tff.SERVER`. - - This generalized aggregation function admits multi-layered architectures that - involve one or more intermediate stages to handle scalable aggregation across - a very large number of participants. - - The multi-stage aggregation process is defined as follows: - - * Clients are organized into groups. Within each group, a set of all the - member constituents of `value` contributed by clients in the group are first - reduced using reduction operator `accumulate` with `zero` as the zero in the - algebra. If members of `value` are of type `T`, and `zero` (the result of - reducing an empty set) is of type `U`, the reduction operator `accumulate` - used at this stage should be of type `( -> U)`. The result of this - stage is a set of items of type `U`, one item for each group of clients. - - * Next, the `U`-typed items generated by the preceding stage are merged using - the binary commutative associative operator `merge` of type `( -> U)`. - The result of this stage is a single top-level `U` that emerges at the root - of the hierarchy at the `tff.SERVER`. Actual implementations may structure - this step as a cascade of multiple layers. - - * Finally, the `U`-typed result of the reduction performed in the preceding - stage is projected into the result value using `report` as the mapping - function (for example, if the structures being merged consist of counters, - this final step might include computing their ratios). - - Args: - value: A value of a TFF federated type placed at `tff.CLIENTS` to aggregate. - zero: The zero of type `U` in the algebra of reduction operators, as - described above. - accumulate: The reduction operator to use in the first stage of the process. - If `value` is of type `{T}@CLIENTS`, and `zero` is of type `U`, this - operator should be of type `( -> U)`. - merge: The reduction operator to employ in the second stage of the process. - Must be of type `( -> U)`, where `U` is as defined above. - report: The projection operator to use at the final stage of the process to - compute the final result of aggregation. If the intended result to be - returned by `tff.federated_aggregate` is of type `R@SERVER`, this operator - must be of type `(U -> R)`. - - Returns: - A representation on the `tff.SERVER` of the result of aggregating `value` - using the multi-stage process described above. - - Raises: - TypeError: If the arguments are not of the types specified above. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to be aggregated' - ) - - zero = value_impl.to_value(zero, type_spec=None) - py_typecheck.check_type(zero, value_impl.Value) - accumulate = value_impl.to_value( - accumulate, - type_spec=None, - parameter_type_hint=computation_types.StructType([ - zero.type_signature, - value.type_signature.member, # pytype: disable=attribute-error - ]), - ) - merge = value_impl.to_value( - merge, - type_spec=None, - parameter_type_hint=computation_types.StructType( - [accumulate.type_signature.result] # pytype: disable=attribute-error - * 2 - ), - ) - report = value_impl.to_value( - report, - type_spec=None, - parameter_type_hint=merge.type_signature.result, # pytype: disable=attribute-error - ) - for op in [accumulate, merge, report]: - py_typecheck.check_type(op, value_impl.Value) - py_typecheck.check_type(op.type_signature, computation_types.FunctionType) - - if not accumulate.type_signature.parameter[0].is_assignable_from( - zero.type_signature - ): # pytype: disable=attribute-error - raise TypeError( - 'Expected `zero` to be assignable to type {}, ' - 'but was of incompatible type {}.'.format( - accumulate.type_signature.parameter[0], zero.type_signature # pytype: disable=attribute-error - ) - ) - - accumulate_type_expected = type_factory.reduction_op( - accumulate.type_signature.result, # pytype: disable=attribute-error - value.type_signature.member, # pytype: disable=attribute-error - ) - merge_type_expected = type_factory.reduction_op( - accumulate.type_signature.result, # pytype: disable=attribute-error - accumulate.type_signature.result, # pytype: disable=attribute-error - ) - report_type_expected = computation_types.FunctionType( - merge.type_signature.result, # pytype: disable=attribute-error - report.type_signature.result, # pytype: disable=attribute-error - ) - for op_name, op, type_expected in [ - ('accumulate', accumulate, accumulate_type_expected), - ('merge', merge, merge_type_expected), - ('report', report, report_type_expected), - ]: - if not type_expected.is_assignable_from(op.type_signature): - raise TypeError( - 'Expected parameter `{}` to be of type {}, but received {} instead.' - .format(op_name, type_expected, op.type_signature) - ) - - comp = building_block_factory.create_federated_aggregate( - value.comp, zero.comp, accumulate.comp, merge.comp, report.comp - ) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_broadcast(value): - """Broadcasts a federated value from the `tff.SERVER` to the `tff.CLIENTS`. - - Args: - value: A value of a TFF federated type placed at the `tff.SERVER`, all - members of which are equal (the `tff.FederatedType.all_equal` property of - `value` is `True`). - - Returns: - A representation of the result of broadcasting: a value of a TFF federated - type placed at the `tff.CLIENTS`, all members of which are equal. - - Raises: - TypeError: If the argument is not a federated TFF value placed at the - `tff.SERVER`. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.SERVER, 'value to be broadcasted' - ) - - if not value.type_signature.all_equal: # pytype: disable=attribute-error - raise TypeError('The broadcasted value should be equal at all locations.') - - comp = building_block_factory.create_federated_broadcast(value.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_eval(fn, placement): - """Evaluates a federated computation at `placement`, returning the result. - - Args: - fn: A no-arg TFF computation. - placement: The desired result placement (either `tff.SERVER` or - `tff.CLIENTS`). - - Returns: - A federated value with the given placement `placement`. - - Raises: - TypeError: If the arguments are not of the appropriate types. - """ - # TODO: b/113112108 - Verify that neither the value, nor any of its parts - # are of a federated type. - - fn = value_impl.to_value(fn, type_spec=None) - py_typecheck.check_type(fn, value_impl.Value) - py_typecheck.check_type(fn.type_signature, computation_types.FunctionType) - - if fn.type_signature.parameter is not None: # pytype: disable=attribute-error - raise TypeError( - '`federated_eval` expects a `fn` that accepts no arguments, but ' - 'the `fn` provided has a parameter of type {}.'.format( - fn.type_signature.parameter # pytype: disable=attribute-error - ) - ) - - comp = building_block_factory.create_federated_eval(fn.comp, placement) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_map(fn, arg): - """Maps a federated value pointwise using a mapping function. - - The function `fn` is applied separately across the group of devices - represented by the placement type of `arg`. For example, if `value` has - placement type `tff.CLIENTS`, then `fn` is applied to each client - individually. In particular, this operation does not alter the placement of - the federated value. - - Args: - fn: A mapping function to apply pointwise to member constituents of `arg`. - The parameter of this function must be of the same type as the member - constituents of `arg`. - arg: A value of a TFF federated type (or a value that can be implicitly - converted into a TFF federated type, e.g., by zipping) placed at - `tff.CLIENTS` or `tff.SERVER`. - - Returns: - A federated value with the same placement as `arg` that represents the - result of `fn` on the member constituent of `arg`. - - Raises: - TypeError: If the arguments are not of the appropriate types. - """ - # TODO: b/113112108 - Possibly lift the restriction that the mapped value - # must be placed at the server or clients. Would occur after adding support - # for placement labels in the federated types, and expanding the type - # specification of the intrinsic this is based on to work with federated - # values of arbitrary placement. - - arg = value_impl.to_value(arg, type_spec=None) - arg = value_utils.ensure_federated_value(arg, label='value to be mapped') - - fn = value_impl.to_value( - fn, type_spec=None, parameter_type_hint=arg.type_signature.member # pytype: disable=attribute-error - ) - - py_typecheck.check_type(fn, value_impl.Value) - py_typecheck.check_type(fn.type_signature, computation_types.FunctionType) - if not fn.type_signature.parameter.is_assignable_from( # pytype: disable=attribute-error - arg.type_signature.member # pytype: disable=attribute-error - ): - raise TypeError( - 'The mapping function expects a parameter of type {}, but member ' - 'constituents of the mapped value are of incompatible type {}.'.format( - fn.type_signature.parameter, # pytype: disable=attribute-error - arg.type_signature.member, # pytype: disable=attribute-error - ) - ) - - # TODO: b/144384398 - Change structure to one that maps the placement type - # to the building_block function that fits it, in a way that allows the - # appropriate type checks. - if arg.type_signature.placement is placements.SERVER: # pytype: disable=attribute-error - if not arg.type_signature.all_equal: # pytype: disable=attribute-error - raise TypeError( - 'Arguments placed at {} should be equal at all locations.'.format( - placements.SERVER - ) - ) - comp = building_block_factory.create_federated_apply(fn.comp, arg.comp) - elif arg.type_signature.placement is placements.CLIENTS: # pytype: disable=attribute-error - comp = building_block_factory.create_federated_map(fn.comp, arg.comp) - else: - raise TypeError( - 'Expected `arg` to have a type with a supported placement, ' - 'found {}.'.format(arg.type_signature.placement) # pytype: disable=attribute-error - ) - - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_map_all_equal(fn, arg): - """`federated_map` with the `all_equal` bit set in the `arg` and return.""" - # TODO: b/113112108 - Possibly lift the restriction that the mapped value - # must be placed at the clients after adding support for placement labels - # in the federated types, and expanding the type specification of the - # intrinsic this is based on to work with federated values of arbitrary - # placement. - arg = value_impl.to_value(arg, type_spec=None) - arg = value_utils.ensure_federated_value( - arg, placements.CLIENTS, 'value to be mapped' - ) - - fn = value_impl.to_value( - fn, type_spec=None, parameter_type_hint=arg.type_signature.member # pytype: disable=attribute-error - ) - - py_typecheck.check_type(fn, value_impl.Value) - py_typecheck.check_type(fn.type_signature, computation_types.FunctionType) - if not fn.type_signature.parameter.is_assignable_from( - arg.type_signature.member # pytype: disable=attribute-error - ): # pytype: disable=attribute-error - raise TypeError( - 'The mapping function expects a parameter of type {}, but member ' - 'constituents of the mapped value are of incompatible type {}.'.format( - fn.type_signature.parameter, # pytype: disable=attribute-error - arg.type_signature.member, # pytype: disable=attribute-error - ) - ) - - comp = building_block_factory.create_federated_map_all_equal( - fn.comp, arg.comp - ) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_mean(value, weight=None): - """Computes a `tff.SERVER` mean of `value` placed on `tff.CLIENTS`. - - For values `v_1, ..., v_k`, and weights `w_1, ..., w_k`, this means - `sum_{i=1}^k (w_i * v_i) / sum_{i=1}^k w_i`. - - Args: - value: The value of which the mean is to be computed. Must be of a TFF - federated type placed at `tff.CLIENTS`. The value may be structured, e.g., - its member constituents can be named tuples. The tensor types that the - value is composed of must be floating-point or complex. - weight: An optional weight, a TFF federated integer or floating-point tensor - value, also placed at `tff.CLIENTS`. - - Returns: - A representation at the `tff.SERVER` of the mean of the member constituents - of `value`, optionally weighted with `weight` if specified (otherwise, the - member constituents contributed by all clients are equally weighted). - - Raises: - TypeError: If `value` is not a federated TFF value placed at `tff.CLIENTS`, - or if `weight` is not a federated integer or a floating-point tensor with - the matching placement. - """ - # TODO: b/113112108 - Possibly relax the constraints on numeric types, and - # inject implicit casts where appropriate. For instance, we might want to - # allow `np.int32` values as the input, and automatically cast them to - # `np.float32`` before invoking the average, thus producing a floating-point - # result. - - # TODO: b/120439632 - Possibly allow the weight to be either structured or - # non-scalar, e.g., for the case of averaging a convolutional layer, when - # we would want to use a different weight for every filter, and where it - # might be cumbersome for users to have to manually slice and assemble a - # variable. - - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to be averaged' - ) - if not type_analysis.is_average_compatible(value.type_signature): - raise TypeError( - 'The value type {} is not compatible with the average operator.'.format( - value.type_signature - ) - ) - - if weight is not None: - weight = value_impl.to_value(weight, type_spec=None) - weight = value_utils.ensure_federated_value( - weight, placements.CLIENTS, 'weight to use in averaging' - ) - if ( - not isinstance(weight.type_signature, computation_types.FederatedType) - or not isinstance( - weight.type_signature.member, computation_types.TensorType - ) - or not array_shape.is_shape_scalar(weight.type_signature.member.shape) - ): - raise TypeError( - 'The weight type {} is not a federated scalar.'.format( - weight.type_signature - ) - ) - if not np.issubdtype( - weight.type_signature.member.dtype, # pytype: disable=attribute-error - np.integer, - ) and not np.issubdtype( - weight.type_signature.member.dtype, # pytype: disable=attribute-error - np.floating, - ): - raise TypeError( - 'The weight type {} is not a federated integer or floating-point ' - 'tensor.'.format(weight.type_signature) - ) - - weight_comp = None if weight is None else weight.comp - comp = building_block_factory.create_federated_mean(value.comp, weight_comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_min(value: object) -> value_impl.Value: - """Computes a min at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. - - Args: - value: A value of a TFF federated type placed at the `tff.CLIENTS`. - - Returns: - A representation of the min of the member constituents of `value` placed on - the `tff.SERVER`. - - Raises: - ValueError: If the argument is not a federated TFF value placed at - `tff.CLIENTS` compatible with min. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to take min of' - ) - if not type_analysis.is_min_max_compatible(value.type_signature): - raise ValueError( - 'The value type {} is not compatible with the min operator.'.format( - value.type_signature - ) - ) - comp = building_block_factory.create_federated_min(value.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_max(value: object) -> value_impl.Value: - """Computes a max at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. - - Args: - value: A value of a TFF federated type placed at the `tff.CLIENTS`. - - Returns: - A representation of the max of the member constituents of `value` placed on - the `tff.SERVER`. - - Raises: - ValueError: If the argument is not a federated TFF value placed at - `tff.CLIENTS` compatible with max. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to take max of' - ) - if not type_analysis.is_min_max_compatible(value.type_signature): - raise ValueError( - 'The value type {} is not compatible with the max operator.'.format( - value.type_signature - ) - ) - comp = building_block_factory.create_federated_max(value.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_sum(value): - """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. - - To sum integer values with stronger privacy properties, consider using - `tff.federated_secure_sum_bitwidth`. - - Args: - value: A value of a TFF federated type placed at the `tff.CLIENTS`. - - Returns: - A representation of the sum of the member constituents of `value` placed - on the `tff.SERVER`. - - Raises: - TypeError: If the argument is not a federated TFF value placed at - `tff.CLIENTS`. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to be summed' - ) - type_analysis.check_is_sum_compatible(value.type_signature) - comp = building_block_factory.create_federated_sum(value.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_value(value, placement): - """Returns a federated value at `placement`, with `value` as the constituent. - - Deprecation warning: Using `tff.federated_value` with arguments other than - simple Python constants is deprecated. When placing the result of a - `tff.tensorflow.computation`, prefer `tff.federated_eval`. - - Args: - value: A value of a non-federated TFF type to be placed. - placement: The desired result placement (either `tff.SERVER` or - `tff.CLIENTS`). - - Returns: - A federated value with the given placement `placement`, and the member - constituent `value` equal at all locations. - - Raises: - TypeError: If the arguments are not of the appropriate types. - """ - if isinstance(value, value_impl.Value): - warnings.warn( - ( - 'Deprecation warning: Using `tff.federated_value` with arguments' - ' other than simple Python constants is deprecated. When placing' - ' the result of a `tff.tensorflow.computation`, prefer' - ' `tff.federated_eval`.' - ), - DeprecationWarning, - ) - value = value_impl.to_value(value, type_spec=None) - if type_analysis.contains( - value.type_signature, - lambda t: isinstance(t, computation_types.FederatedType), - ): - raise TypeError( - 'Cannot place value {} containing federated types at ' - 'another placement; requested to be placed at {}.'.format( - value, placement - ) - ) - - comp = building_block_factory.create_federated_value(value.comp, placement) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_zip(value): - """Converts an N-tuple of federated values into a federated N-tuple value. - - Args: - value: A value of a TFF named tuple type, the elements of which are - federated values with the same placement. - - Returns: - A federated value placed at the same location as the members of `value`, in - which every member component is a named tuple that consists of the - corresponding member components of the elements of `value`. - - Raises: - TypeError: If the argument is not a named tuple of federated values with the - same placement. - """ - # TODO: b/113112108 - We use the iterate/unwrap approach below because - # our type system is not powerful enough to express the concept of - # "an operation that takes tuples of T of arbitrary length", and therefore - # the intrinsic federated_zip must only take a fixed number of arguments, - # here fixed at 2. There are other potential approaches to getting around - # this problem (e.g. having the operator act on sequences and thereby - # sidestepping the issue) which we may want to explore. - value = value_impl.to_value(value, type_spec=None) - py_typecheck.check_type(value, value_impl.Value) - py_typecheck.check_type(value.type_signature, computation_types.StructType) - - comp = building_block_factory.create_federated_zip(value.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def _select_parameter_mismatch( - param_type, - type_desc, - name, - secure, - expected_type=None, -) -> NoReturn: - """Throws a `TypeError` indicating a mismatched `select` parameter type.""" - secure_string = '_secure' if secure else '' - intrinsic_name = f'federated{secure_string}_select' - message = f'Expected `{intrinsic_name}` parameter `{name}` to be {type_desc}' - if expected_type is None: - raise TypeError(f'{message}, found value of type {param_type}') - else: - raise TypeError( - f'{message}:\n' - + computation_types.type_mismatch_error_message( - param_type, - expected_type, - computation_types.TypeRelation.ASSIGNABLE, - second_is_expected=True, - ) - ) - - -def _check_select_keys_type( - keys_type: computation_types.Type, secure: bool -) -> None: - """Checks the federated select keys types.""" - if ( - not isinstance(keys_type, computation_types.FederatedType) - or keys_type.placement is not placements.CLIENTS - ): - _select_parameter_mismatch( - keys_type, 'a federated value placed at clients', 'client_keys', secure - ) - if not ( - isinstance(keys_type.member, computation_types.TensorType) - and keys_type.member.dtype == np.int32 - and len(keys_type.member.shape) == 1 - and keys_type.member.shape[0] is not None - ): - _select_parameter_mismatch( - keys_type.member, # pytype: disable=attribute-error - 'a rank-1 tensor with statically known shape and `np.int32` dtype', - 'client_keys.type_signature.member', - secure, - ) - - -def federated_select(client_keys, max_key, server_val, select_fn): - """Sends selected values from a server database to clients. - - Args: - client_keys: `tff.CLIENTS`-placed one-dimensional fixed-size non-negative - `int32` keys used to select values from `database` to load for each - client. - max_key: A `tff.SERVER`-placed `int32` indicating the maximum value of any - of `client_keys` (so that all client keys are in the range `[0, max_key]`, - inclusive). Lower values may permit more optimizations. - server_val: `tff.SERVER`-placed value used as an input to `select_fn`. - select_fn: A `tff.Computation` which accepts unplaced `server_val` and a - `int32` client key and returns a value to be sent to the client. - `select_fn` should be deterministic (nonrandom). - - Returns: - `tff.CLIENTS`-placed sequences of values returned from `select_fn`. In each - sequence, the order of values will match the order of keys in the - corresponding `client_keys` tensor. For example, a client with keys - `[1, 2, ...]` will receive a sequence of values - `[select_fn(server_val, 1), select_fn(server_val, 2), ...]`. - - Raises: - TypeError: If `client_keys` is not of type `{int32[N]}@CLIENTS`, if - `max_key` is not of type `int32@SERVER`, if `server_val` is not a - server-placed value (`S@SERVER`), or if `select_fn` is not a function - of type ` -> RESULT`. - """ - return _federated_select( - client_keys, max_key, server_val, select_fn, secure=False - ) - - -def federated_secure_select(client_keys, max_key, server_val, select_fn): - """Sends privately-selected values from a server database to clients. - - Args: - client_keys: `tff.CLIENTS`-placed one-dimensional fixed-size non-negative - `int32` keys used to select values from `database` to load for each - client. - max_key: A `tff.SERVER`-placed `int32` which is guaranteed to be greater - than any of `client_keys`. Lower values may permit more optimizations. - server_val: `tff.SERVER`-placed value used as an input to `select_fn`. - select_fn: A `tff.Computation` which accepts unplaced `server_val` and a - `int32` client key and returns a value to be sent to the client. - `select_fn` should be deterministic (nonrandom). - - Returns: - `tff.CLIENTS`-placed sequences of values returned from `select_fn`. In each - sequence, the order of values will match the order of keys in the - corresponding `client_keys` tensor. For example, a client with keys - `[1, 2, ...]` will receive a sequence of values - `[select_fn(server_val, 1), select_fn(server_val, 2), ...]`. - - Raises: - TypeError: If `client_keys` is not of type `{int32[N]}@CLIENTS`, if - `max_key` is not of type `int32@SERVER`, if `server_val` is not a - server-placed value (`S@SERVER`), or if `select_fn` is not a function - of type ` -> RESULT`. - """ - return _federated_select( - client_keys, max_key, server_val, select_fn, secure=True - ) - - -def _federated_select(client_keys, max_key, server_val, select_fn, secure): - """Internal helper for `federated_select` and `federated_secure_select`.""" - client_keys = value_impl.to_value(client_keys, type_spec=None) - _check_select_keys_type(client_keys.type_signature, secure) - max_key = value_impl.to_value(max_key, type_spec=None) - expected_max_key_type = computation_types.FederatedType( - np.int32, placements.SERVER - ) - if not expected_max_key_type.is_assignable_from(max_key.type_signature): - _select_parameter_mismatch( - max_key.type_signature, - 'a 32-bit unsigned integer placed at server', - 'max_key', - secure, - expected_type=expected_max_key_type, - ) - server_val = value_impl.to_value(server_val, type_spec=None) - server_val = value_utils.ensure_federated_value( - server_val, label='server_val' - ) - expected_server_val_type = computation_types.FederatedType( - computation_types.AbstractType('T'), placements.SERVER - ) - if ( - not isinstance(server_val.type_signature, computation_types.FederatedType) - or server_val.type_signature.placement is not placements.SERVER - ): - _select_parameter_mismatch( - server_val.type_signature, - 'a value placed at server', - 'server_val', - secure, - expected_type=expected_server_val_type, - ) - select_fn_param_type = computation_types.StructType([ - server_val.type_signature.member, # pytype: disable=attribute-error - np.int32, - ]) - select_fn = value_impl.to_value( - select_fn, type_spec=None, parameter_type_hint=select_fn_param_type - ) - expected_select_fn_type = computation_types.FunctionType( - select_fn_param_type, computation_types.AbstractType('U') - ) - if not isinstance( - select_fn.type_signature, computation_types.FunctionType - ) or not select_fn.type_signature.parameter.is_assignable_from( - select_fn_param_type - ): - _select_parameter_mismatch( - select_fn.type_signature, - 'a function from state and key to result', - 'select_fn', - secure, - expected_type=expected_select_fn_type, - ) - comp = building_block_factory.create_federated_select( - client_keys.comp, max_key.comp, server_val.comp, select_fn.comp, secure - ) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_secure_sum(value, max_input): - """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. - - This function computes a sum such that it should not be possible for the - server to learn any clients individual value. The specific algorithm and - mechanism used to compute the secure sum may vary depending on the target - runtime environment the computation is compiled for or executed on. See - https://research.google/pubs/pub47246/ for more information. - - Not all executors support `tff.federated_secure_sum()`; consult the - documentation for the specific executor or executor stack you plan on using - for the specific of how it's handled by that executor. - - The `max_input` argument is the maximum value (inclusive) that may appear in - `value`. *Lower values may allow for improved communication efficiency.* - Attempting to return a `value` higher than `max_input` is invalid, and will - result in a failure at the given client. - - Example: - - ```python - value = tff.federated_value(1, tff.CLIENTS) - result = tff.federated_secure_sum(value, 1) - - value = tff.federated_value((1, 2), tff.CLIENTS) - result = tff.federated_secure_sum(value, (1, 2)) - ``` - - Note: To sum non-integer values or to sum integers with fewer constraints and - weaker privacy properties, consider using `federated_sum`. - - Args: - value: An integer or nested structure of integers placed at `tff.CLIENTS`, - in the range `[0, max_input]`. - max_input: A Python integer or nested structure of integers matching the - structure of `value`. If integer `max_value` is used with a nested - `value`, the same integer is used for each tensor in `value`. - - Returns: - A representation of the sum of the member constituents of `value` placed - on the `tff.SERVER`. - - Raises: - TypeError: If the argument is not a federated TFF value placed at - `tff.CLIENTS`. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to be summed' - ) - type_analysis.check_is_structure_of_integers(value.type_signature) - max_input_value = value_impl.to_value(max_input, type_spec=None) - value_member_type = value.type_signature.member # pytype: disable=attribute-error - max_input_type = max_input_value.type_signature - if not type_analysis.is_single_integer_or_matches_structure( - max_input_type, value_member_type - ): - raise TypeError( - 'Expected `federated_secure_sum` parameter `max_input` to match ' - 'the structure of `value`, with one integer max per tensor in ' - '`value`. Found `value` of `{}` and `max_input` of `{}`.'.format( - value_member_type, max_input_type - ) - ) - if isinstance(max_input_type, computation_types.TensorType) and isinstance( - value_member_type, computation_types.StructType - ): - max_input_value = value_impl.to_value( - structure.map_structure(lambda _: max_input, value_member_type), - type_spec=None, - ) - comp = building_block_factory.create_federated_secure_sum( - value.comp, max_input_value.comp - ) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def federated_secure_sum_bitwidth(value, bitwidth): - """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`. - - This function computes a sum such that it should not be possible for the - server to learn any clients individual value. The specific algorithm and - mechanism used to compute the secure sum may vary depending on the target - runtime environment the computation is compiled for or executed on. See - https://research.google/pubs/pub47246/ for more information. - - Not all executors support `tff.federated_secure_sum_bitwidth()`; consult the - documentation for the specific executor or executor stack you plan on using - for the specific of how it's handled by that executor. - - The `bitwidth` argument represents the bitwidth of the aggregand, that is the - bitwidth of the input `value`. The federated secure sum bitwidth (i.e., the - bitwidth of the *sum* of the input `value`s over all clients) will be a - function of this bitwidth and the number of participating clients. - - Example: - - ```python - value = tff.federated_value(1, tff.CLIENTS) - result = tff.federated_secure_sum_bitwidth(value, 2) - - value = tff.federated_value([1, 1], tff.CLIENTS) - result = tff.federated_secure_sum_bitwidth(value, [2, 4]) - - value = tff.federated_value([1, [1, 1]], tff.CLIENTS) - result = tff.federated_secure_sum_bitwidth(value, [2, [4, 8]]) - ``` - - Note: To sum non-integer values or to sum integers with fewer constraints and - weaker privacy properties, consider using `federated_sum`. - - Args: - value: An integer value of a TFF federated type placed at the `tff.CLIENTS`, - in the range [0, 2^bitwidth - 1]. - bitwidth: An integer or nested structure of integers matching the structure - of `value`. If integer `bitwidth` is used with a nested `value`, the same - integer is used for each tensor in `value`. - - Returns: - A representation of the sum of the member constituents of `value` placed - on the `tff.SERVER`. - - Raises: - TypeError: If the argument is not a federated TFF value placed at - `tff.CLIENTS`. - """ - value = value_impl.to_value(value, type_spec=None) - value = value_utils.ensure_federated_value( - value, placements.CLIENTS, 'value to be summed' - ) - type_analysis.check_is_structure_of_integers(value.type_signature) - bitwidth_value = value_impl.to_value(bitwidth, type_spec=None) - value_member_type = value.type_signature.member # pytype: disable=attribute-error - bitwidth_type = bitwidth_value.type_signature - if not type_analysis.is_single_integer_or_matches_structure( - bitwidth_type, value_member_type - ): - raise TypeError( - 'Expected `federated_secure_sum_bitwidth` parameter `bitwidth` to ' - 'match the structure of `value`, with one integer bitwidth per tensor ' - 'in `value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format( - value_member_type, bitwidth_type - ) - ) - if isinstance(bitwidth_type, computation_types.TensorType) and isinstance( - value_member_type, computation_types.StructType - ): - bitwidth_value = value_impl.to_value( - structure.map_structure(lambda _: bitwidth, value_member_type), - type_spec=None, - ) - comp = building_block_factory.create_federated_secure_sum_bitwidth( - value.comp, bitwidth_value.comp - ) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - - -def sequence_map(fn, arg): - """Maps a TFF sequence `value` pointwise using a given function `fn`. - - This function supports two modes of usage: - - * When applied to a non-federated sequence, it maps individual elements of - the sequence pointwise. If the supplied `fn` is of type `T->U` and - the sequence `arg` is of type `T*` (a sequence of `T`-typed elements), - the result is a sequence of type `U*` (a sequence of `U`-typed elements), - with each element of the input sequence individually mapped by `fn`. - In this mode of usage, `sequence_map` behaves like a computatation with type - signature `U,T*> -> U*`. - - * When applied to a federated sequence, `sequence_map` behaves as if it were - individually applied to each member constituent. In this mode of usage, one - can think of `sequence_map` as a specialized variant of `federated_map` that - is designed to work with sequences and allows one to - specify a `fn` that operates at the level of individual elements. - Indeed, under the hood, when `sequence_map` is invoked on a federated type, - it injects `federated_map`, thus - emitting expressions like - `federated_map(a -> sequence_map(fn, x), arg)`. - - Args: - fn: A mapping function to apply pointwise to elements of `arg`. - arg: A value of a TFF type that is either a sequence, or a federated - sequence. - - Returns: - A sequence with the result of applying `fn` pointwise to each - element of `arg`, or if `arg` was federated, a federated sequence - with the result of invoking `sequence_map` on member sequences locally - and independently at each location. - - Raises: - TypeError: If the arguments are not of the appropriate types. - """ - fn = value_impl.to_value(fn, type_spec=None) - py_typecheck.check_type(fn.type_signature, computation_types.FunctionType) - arg = value_impl.to_value(arg, type_spec=None) - - if isinstance(arg.type_signature, computation_types.SequenceType): - comp = building_block_factory.create_sequence_map(fn.comp, arg.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - elif isinstance(arg.type_signature, computation_types.FederatedType): - parameter_type = computation_types.SequenceType(fn.type_signature.parameter) # pytype: disable=attribute-error - result_type = computation_types.SequenceType(fn.type_signature.result) # pytype: disable=attribute-error - intrinsic_type = computation_types.FunctionType( - (fn.type_signature, parameter_type), result_type - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.SEQUENCE_MAP.uri, intrinsic_type - ) - intrinsic_impl = value_impl.Value(intrinsic) - local_fn = value_utils.get_curried(intrinsic_impl)(fn) - return federated_map(local_fn, arg) - else: - raise TypeError( - 'Cannot apply `tff.sequence_map()` to a value of type {}.'.format( - arg.type_signature - ) - ) - - -def sequence_reduce(value, zero, op): - """Reduces a TFF sequence `value` given a `zero` and reduction operator `op`. - - This method reduces a set of elements of a TFF sequence `value`, using a given - `zero` in the algebra (i.e., the result of reducing an empty sequence) of some - type `U`, and a reduction operator `op` with type signature `( -> U)` - that incorporates a single `T`-typed element of `value` into the `U`-typed - result of partial reduction. In the special case of `T` equal to `U`, this - corresponds to the classical notion of reduction of a set using a commutative - associative binary operator. The generalized reduction (with `T` not equal to - `U`) requires that repeated application of `op` to reduce a set of `T` always - yields the same `U`-typed result, regardless of the order in which elements - of `T` are processed in the course of the reduction. - - One can also invoke `sequence_reduce` on a federated sequence, in which case - the reductions are performed pointwise; under the hood, we construct an - expression of the form - `federated_map(x -> sequence_reduce(x, zero, op), value)`. See also the - discussion on `sequence_map`. - - Note: When applied to a federated value this function does the reduce - point-wise. - - Args: - value: A value that is either a TFF sequence, or a federated sequence. - zero: The result of reducing a sequence with no elements. - op: An operator with type signature `( -> U)`, where `T` is the type of - the elements of the sequence, and `U` is the type of `zero` to be used in - performing the reduction. - - Returns: - The `U`-typed result of reducing elements in the sequence, or if the `value` - is federated, a federated `U` that represents the result of locally - reducing each member constituent of `value`. - - Raises: - TypeError: If the arguments are not of the types specified above. - """ - value = value_impl.to_value(value, type_spec=None) - zero = value_impl.to_value(zero, type_spec=None) - op = value_impl.to_value(op, type_spec=None) - # Check if the value is a federated sequence that should be reduced - # under a `federated_map`. - if isinstance(value.type_signature, computation_types.FederatedType): - value_member_type = value.type_signature.member - if not isinstance(value_member_type, computation_types.SequenceType): - raise ValueError( - f'Expected a `tff.SequenceType`, found {value_member_type}.' - ) - zero_member_type = zero.type_signature.member - ref_type = computation_types.StructType( - [value_member_type, zero_member_type] - ) - ref = building_blocks.Reference('arg', ref_type) - arg1 = building_blocks.Selection(ref, index=0) - arg2 = building_blocks.Selection(ref, index=1) - call = building_block_factory.create_sequence_reduce(arg1, arg2, op.comp) - fn = building_blocks.Lambda(ref.name, ref.type_signature, call) - fn_value_impl = value_impl.Value(fn) - args = building_blocks.Struct([value.comp, zero.comp]) - return federated_map(fn_value_impl, args) - elif isinstance(value.type_signature, computation_types.SequenceType): - comp = building_block_factory.create_sequence_reduce( - value.comp, zero.comp, op.comp - ) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - else: - raise NotImplementedError(f'Unexpected type found: {value.type_signature}.') - - -def sequence_sum(value): - """Computes a sum of elements in a sequence. - - Args: - value: A value of a TFF type that is either a sequence, or a federated - sequence. - - Returns: - The sum of elements in the sequence. If the argument `value` is of a - federated type, the result is also of a federated type, with the sum - computed locally and independently at each location (see also a discussion - on `sequence_map` and `sequence_reduce`). - - Raises: - TypeError: If the arguments are of wrong or unsupported types. - """ - value = value_impl.to_value(value, type_spec=None) - if isinstance(value.type_signature, computation_types.SequenceType): - element_type = value.type_signature.element - else: - py_typecheck.check_type( - value.type_signature, # pytype: disable=attribute-error - computation_types.FederatedType, - ) - py_typecheck.check_type( - value.type_signature.member, # pytype: disable=attribute-error - computation_types.SequenceType, - ) - element_type = value.type_signature.member.element # pytype: disable=attribute-error - type_analysis.check_is_sum_compatible(element_type) - - if isinstance(value.type_signature, computation_types.SequenceType): - comp = building_block_factory.create_sequence_sum(value.comp) - comp = _bind_comp_as_reference(comp) - return value_impl.Value(comp) - elif isinstance(value.type_signature, computation_types.FederatedType): - intrinsic_type = computation_types.FunctionType( - value.type_signature.member, value.type_signature.member.element - ) - intrinsic = building_blocks.Intrinsic( - intrinsic_defs.SEQUENCE_SUM.uri, intrinsic_type - ) - intrinsic_impl = value_impl.Value(intrinsic) - return federated_map(intrinsic_impl, value) - else: - raise TypeError( - 'Cannot apply `tff.sequence_sum()` to a value of type {}.'.format( - value.type_signature - ) - ) diff --git a/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py b/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py deleted file mode 100644 index b4cd5dd6ff..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/intrinsics_test.py +++ /dev/null @@ -1,987 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation_context -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -_INT = computation_types.TensorType(np.int32) -_INT_CLIENTS = computation_types.FederatedType(_INT, placements.CLIENTS) -_INT_SERVER = computation_types.FederatedType(_INT, placements.SERVER) -_FLOAT = computation_types.TensorType(np.float32) -_FLOAT_CLIENTS = computation_types.FederatedType(_FLOAT, placements.CLIENTS) -_FLOAT_SERVER = computation_types.FederatedType(_FLOAT, placements.SERVER) -_STR = computation_types.TensorType(np.str_) -_STR_CLIENTS = computation_types.FederatedType(_STR, placements.CLIENTS) -_ARRAY_INT = computation_types.TensorType(np.int32, shape=[3]) -_ARRAY_INT_CLIENTS = computation_types.FederatedType( - _ARRAY_INT, placements.CLIENTS -) -_ARRAY_INT_SERVER = computation_types.FederatedType( - _ARRAY_INT, placements.SERVER -) -_SEQUENCE_INT = computation_types.SequenceType(np.int32) -_SEQUENCE_INT_CLIENTS = computation_types.FederatedType( - _SEQUENCE_INT, placements.CLIENTS -) -_SEQUENCE_INT_SERVER = computation_types.FederatedType( - _SEQUENCE_INT, placements.SERVER -) -_SEQUENCE_FLOAT = computation_types.SequenceType(np.float32) -_SEQUENCE_FLOAT_CLIENTS = computation_types.FederatedType( - _SEQUENCE_FLOAT, placements.CLIENTS -) -_SEQUENCE_FLOAT_SERVER = computation_types.FederatedType( - _SEQUENCE_FLOAT, placements.SERVER -) -_STRUCT_INT = computation_types.StructWithPythonType( - [np.int32, np.int32, np.int32], list -) -_STRUCT_INT_CLIENTS = computation_types.FederatedType( - _STRUCT_INT, placements.CLIENTS -) -_STRUCT_INT_SERVER = computation_types.FederatedType( - _STRUCT_INT, placements.SERVER -) -_STRUCT_FLOAT = computation_types.StructWithPythonType( - [np.float32, np.float32, np.float32], list -) -_STRUCT_FLOAT_CLIENTS = computation_types.FederatedType( - _STRUCT_FLOAT, placements.CLIENTS -) - - -def _create_context() -> ( - federated_computation_context.FederatedComputationContext -): - return federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - - -def _create_fake_fn( - parameter_type: Optional[computation_types.TensorType], - result_type: computation_types.TensorType, -) -> value_impl.Value: - result = building_blocks.Reference('result', result_type) - parameter_name = None if parameter_type is None else 'arg' - fn = building_blocks.Lambda(parameter_name, parameter_type, result) - return value_impl.Value(fn) - - -def _create_fake_value(type_spec: computation_types.Type) -> value_impl.Value: - value = building_blocks.Reference('value', type_spec) - return value_impl.Value(value) - - -class FederatedBroadcastTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('int_server', _create_fake_value(_INT_SERVER)), - ('sequence_server', _create_fake_value(_SEQUENCE_INT_SERVER)), - ('struct_server', _create_fake_value(_STRUCT_INT_SERVER)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value): - result = intrinsics.federated_broadcast(value) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.CLIENTS, all_equal=True - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ('int_unplaced', _create_fake_value(_INT)), - ('int_clients', _create_fake_value(_INT_CLIENTS)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value): - with self.assertRaises(TypeError): - intrinsics.federated_broadcast(value) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_SERVER) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_broadcast(value) - - -class FederatedEvalTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'int_and_clients', - _create_fake_fn(None, _INT), - placements.CLIENTS, - ), - ( - 'int_and_server', - _create_fake_fn(None, _INT), - placements.SERVER, - ), - ( - 'sequence_and_clieints', - _create_fake_fn(None, _SEQUENCE_INT), - placements.CLIENTS, - ), - ( - 'struct_and_clients', - _create_fake_fn(None, _STRUCT_INT), - placements.CLIENTS, - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, fn, placement): - result = intrinsics.federated_eval(fn, placement) - - expected_type = computation_types.FederatedType( - fn.type_signature.result, placement - ) - self.assertEqual(result.type_signature, expected_type) - - def test_raises_context_error_with_no_federated_context(self): - fn = _create_fake_fn(None, _INT) - placement = placements.CLIENTS - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_eval(fn, placement) - - -class FederatedMapTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'int_clients', - _create_fake_fn(_INT, _FLOAT), - _create_fake_value(_INT_CLIENTS), - _FLOAT_CLIENTS, - ), - ( - 'int_server', - _create_fake_fn(_INT, _FLOAT), - _create_fake_value(_INT_SERVER), - _FLOAT_SERVER, - ), - ( - 'sequence_clients', - _create_fake_fn(_SEQUENCE_INT, _FLOAT), - _create_fake_value(_SEQUENCE_INT_CLIENTS), - _FLOAT_CLIENTS, - ), - ( - 'struct_clients', - _create_fake_fn(_STRUCT_INT, _FLOAT), - _create_fake_value(_STRUCT_INT_CLIENTS), - _FLOAT_CLIENTS, - ), - ( - 'struct_injected_zip', - _create_fake_fn(_STRUCT_INT, _FLOAT), - [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - ], - _FLOAT_CLIENTS, - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, fn, arg, expected_type): - result = intrinsics.federated_map(fn, arg) - - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ( - 'int_unplaced', - _create_fake_fn(_INT, _FLOAT), - _create_fake_value(_INT), - ), - ( - 'struct_different_placements', - _create_fake_fn(_STRUCT_INT, _FLOAT), - [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_FLOAT_SERVER), - ], - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, fn, arg): - with self.assertRaises(TypeError): - intrinsics.federated_map(fn, arg) - - def test_raises_context_error_with_no_federated_context(self): - fn = _create_fake_fn(_INT, _FLOAT) - arg = _create_fake_value(_INT_CLIENTS) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_map(fn, arg) - - -class FederatedSecureSumTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'value_int_clients_and_max_input_int', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT), - ), - ( - 'value_struct_int_clients_and_max_input_int', - _create_fake_value(_STRUCT_INT_CLIENTS), - _create_fake_value(_INT), - ), - ( - 'value_struct_int_clients_and_max_input_struct', - _create_fake_value(_STRUCT_INT_CLIENTS), - _create_fake_value(_STRUCT_INT), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, max_input): - result = intrinsics.federated_secure_sum(value, max_input) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ( - 'value_int_unplaced', - _create_fake_value(_INT), - _create_fake_value(_INT), - ), - ( - 'value_float_clients', - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_value(_INT), - ), - ( - 'value_int_server', - _create_fake_value(_INT_SERVER), - _create_fake_value(_INT), - ), - ( - 'max_input_int_clients', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - ), - ( - 'max_input_int_server', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - ), - ( - 'mismatched_structures', - _create_fake_value( - computation_types.FederatedType( - [np.int32] * 2, placements.CLIENTS - ), - ), - _create_fake_value(computation_types.StructType([np.int32] * 3)), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value, max_input): - with self.assertRaises(TypeError): - intrinsics.federated_secure_sum(value, max_input) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_CLIENTS) - max_input = _create_fake_value(_INT) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_secure_sum(value, max_input) - - -class FederatedSecureSumBitwidthTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'value_int_clients_and_bitwidth_int', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT), - ), - ( - 'value_struct_int_clients_and_bitwidth_int', - _create_fake_value(_STRUCT_INT_CLIENTS), - _create_fake_value(_INT), - ), - ( - 'value_struct_int_clients_and_bitwidth_struct', - _create_fake_value(_STRUCT_INT_CLIENTS), - _create_fake_value(_STRUCT_INT), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, bitwidth): - result = intrinsics.federated_secure_sum_bitwidth(value, bitwidth) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ( - 'value_int_unplaced', - _create_fake_value(_INT), - _create_fake_value(_INT), - ), - ( - 'value_float_clients', - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_value(_INT), - ), - ( - 'value_int_server', - _create_fake_value(_INT_SERVER), - _create_fake_value(_INT), - ), - ( - 'bitwidth_int_clients', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - ), - ( - 'bitwidth_int_server', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - ), - ( - 'mismatched_structures', - _create_fake_value( - computation_types.FederatedType( - [np.int32] * 2, placements.CLIENTS - ), - ), - _create_fake_value(computation_types.StructType([np.int32] * 3)), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value, bitwidth): - with self.assertRaises(TypeError): - intrinsics.federated_secure_sum_bitwidth(value, bitwidth) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_CLIENTS) - bitwidth = _create_fake_value(_INT) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_secure_sum_bitwidth(value, bitwidth) - - -class FederatedSelectTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'struct_int', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - _SEQUENCE_FLOAT_CLIENTS, - ), - ( - 'struct_injected_zip', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - [ - _create_fake_value(_INT_SERVER), - _create_fake_value(_INT_SERVER), - _create_fake_value(_INT_SERVER), - ], - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - _SEQUENCE_FLOAT_CLIENTS, - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result( - self, - client_keys, - max_key, - server_value, - select_fn, - expected_type, - ): - result = intrinsics.federated_select( - client_keys, max_key, server_value, select_fn - ) - - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ( - 'client_keys_array_unplaced', - _create_fake_value(_ARRAY_INT), - _create_fake_value(_INT_SERVER), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'client_keys_array_server', - _create_fake_value(_ARRAY_INT_SERVER), - _create_fake_value(_INT_SERVER), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'max_key_int_unplaced', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'max_key_int_clients', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'max_key_float_server', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_FLOAT_SERVER), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'server_value_struct_unplaced', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - _create_fake_value(_STRUCT_INT), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'server_value_struct_clients', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - _create_fake_value(_STRUCT_INT_CLIENTS), - _create_fake_fn([_STRUCT_INT, _INT], _FLOAT), - ), - ( - 'select_fn_second_parameter_float', - _create_fake_value(_ARRAY_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - _create_fake_value(_STRUCT_INT_SERVER), - _create_fake_fn([_STRUCT_INT, _FLOAT], _FLOAT), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error( - self, client_keys, max_key, server_value, select_fn - ): - with self.assertRaises(TypeError): - intrinsics.federated_select(client_keys, max_key, server_value, select_fn) - - def test_raises_context_error_with_no_federated_context(self): - client_keys = _create_fake_value(_ARRAY_INT_CLIENTS) - max_key = _create_fake_value(_INT_SERVER) - server_value = _create_fake_value(_STRUCT_INT_SERVER) - select_fn = _create_fake_fn([_STRUCT_INT, _INT], _FLOAT) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_select(client_keys, max_key, server_value, select_fn) - - -class FederatedSumTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('int_clients', _create_fake_value(_INT_CLIENTS)), - ('struct_clients', _create_fake_value(_STRUCT_INT_CLIENTS)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value): - result = intrinsics.federated_sum(value) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ('int_unplaced', _create_fake_value(_INT)), - ('str_clients', _create_fake_value(_STR_CLIENTS)), - ('int_server', _create_fake_value(_INT_SERVER)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value): - with self.assertRaises(TypeError): - intrinsics.federated_sum(value) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_CLIENTS) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_sum(value) - - -class FederatedZipTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'struct_clients', - [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - ], - _STRUCT_INT_CLIENTS, - ), - ( - 'struct_server', - [ - _create_fake_value(_INT_SERVER), - _create_fake_value(_INT_SERVER), - _create_fake_value(_INT_SERVER), - ], - _STRUCT_INT_SERVER, - ), - ( - 'struct_different_dtypes', - [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_value(_STR_CLIENTS), - ], - computation_types.FederatedType( - [np.int32, np.float32, np.str_], placements.CLIENTS - ), - ), - ( - 'struct_one_element', - [ - _create_fake_value(_INT_CLIENTS), - ], - computation_types.FederatedType([np.int32], placements.CLIENTS), - ), - ( - 'struct_nested', - [ - _create_fake_value(_INT_CLIENTS), - [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - ], - ], - computation_types.FederatedType( - [np.int32, [np.int32, np.int32]], placements.CLIENTS - ), - ), - ( - 'struct_named', - { - 'a': _create_fake_value(_INT_CLIENTS), - 'b': _create_fake_value(_INT_CLIENTS), - 'c': _create_fake_value(_INT_CLIENTS), - }, - computation_types.FederatedType( - {'a': np.int32, 'b': np.int32, 'c': np.int32}, placements.CLIENTS - ), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, expected_type): - result = intrinsics.federated_zip(value) - - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ('int_unplaced', _create_fake_value(_INT)), - ('int_clients', _create_fake_value(_INT_CLIENTS)), - ('int_server', _create_fake_value(_INT_SERVER)), - ( - 'struct_different_placements', - [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_SERVER), - ], - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value): - with self.assertRaises(TypeError): - intrinsics.federated_zip(value) - - def test_raises_context_error_with_no_federated_context(self): - value = [ - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT_CLIENTS), - ] - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_zip(value) - - -class FederatedMeanTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'value_float_clients_and_weight_none', - _create_fake_value(_FLOAT_CLIENTS), - None, - ), - ( - 'value_float_clients_and_weight_float_clients', - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_value(_FLOAT_CLIENTS), - ), - ( - 'value_struct_int_clients_and_weight_none', - _create_fake_value(_STRUCT_FLOAT_CLIENTS), - None, - ), - ( - 'value_struct_int_clients_and_weight_float_clients', - _create_fake_value(_STRUCT_FLOAT_CLIENTS), - _create_fake_value(_FLOAT_CLIENTS), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, weight): - result = intrinsics.federated_mean(value, weight) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ( - 'value_int_clients', - _create_fake_value(_INT_CLIENTS), - None, - ), - ( - 'value_float_server', - _create_fake_value(_FLOAT_SERVER), - None, - ), - ( - 'weight_str_clients', - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_value(_STR_CLIENTS), - ), - ( - 'weight_float_server', - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_value(_FLOAT_SERVER), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value, weight): - with self.assertRaises(TypeError): - intrinsics.federated_mean(value, weight) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_FLOAT_CLIENTS) - weight = None - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_mean(value, weight) - - -class FederatedMinTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('int_clients', _create_fake_value(_INT_CLIENTS)), - ('struct_clients', _create_fake_value(_STRUCT_INT_CLIENTS)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value): - result = intrinsics.federated_min(value) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ('int_unplaced', _create_fake_value(_INT)), - ('int_server', _create_fake_value(_INT_SERVER)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value): - with self.assertRaises(TypeError): - intrinsics.federated_min(value) - - @parameterized.named_parameters( - ('str_clients', _create_fake_value(_STR_CLIENTS)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_value_error(self, value): - with self.assertRaises(ValueError): - intrinsics.federated_min(value) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_CLIENTS) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_min(value) - - -class FederatedMaxTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('int_clients', _create_fake_value(_INT_CLIENTS)), - ('struct_clients', _create_fake_value(_STRUCT_INT_CLIENTS)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value): - result = intrinsics.federated_max(value) - - expected_type = computation_types.FederatedType( - value.type_signature.member, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ('int_unplaced', _create_fake_value(_INT)), - ('int_server', _create_fake_value(_INT_SERVER)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value): - with self.assertRaises(TypeError): - intrinsics.federated_max(value) - - @parameterized.named_parameters( - ('str_clients', _create_fake_value(_STR_CLIENTS)), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_value_error(self, value): - with self.assertRaises(ValueError): - intrinsics.federated_max(value) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_CLIENTS) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_min(value) - - -class FederatedAggregateTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'int_clients', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_FLOAT), - _create_fake_fn([_FLOAT, _INT], _FLOAT), - _create_fake_fn([_FLOAT, _FLOAT], _FLOAT), - _create_fake_fn(_FLOAT, _STR), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, zero, accumulate, merge, report): - result = intrinsics.federated_aggregate( - value, zero, accumulate, merge, report - ) - - expected_type = computation_types.FederatedType( - report.type_signature.result, placements.SERVER - ) - self.assertEqual(result.type_signature, expected_type) - - @parameterized.named_parameters( - ( - 'zero_mismatched_type', - _create_fake_value(_INT_CLIENTS), - _create_fake_value(_INT), - _create_fake_fn([_FLOAT, _INT], _FLOAT), - _create_fake_fn([_FLOAT, _FLOAT], _FLOAT), - _create_fake_fn(_FLOAT, _STR), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_raises_type_error(self, value, zero, accumulate, merge, report): - with self.assertRaises(TypeError): - intrinsics.federated_aggregate(value, zero, accumulate, merge, report) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT_CLIENTS) - zero = _create_fake_value(_FLOAT) - accumulate = _create_fake_fn([_FLOAT, _INT], _FLOAT) - merge = _create_fake_fn([_FLOAT, _FLOAT], _FLOAT) - report = _create_fake_fn(_FLOAT, _STR) - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_aggregate(value, zero, accumulate, merge, report) - - -class FederatedValueTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'int_and_clients', - _create_fake_value(_INT), - placements.CLIENTS, - ), - ( - 'int_and_server', - _create_fake_value(_INT), - placements.SERVER, - ), - ( - 'sequence_and_clients', - _create_fake_value(_SEQUENCE_INT), - placements.CLIENTS, - ), - ( - 'struct_and_clients', - _create_fake_value(_STRUCT_INT), - placements.CLIENTS, - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, placement): - result = intrinsics.federated_value(value, placement) - - expected_type = computation_types.FederatedType( - value.type_signature, placement, all_equal=True - ) - self.assertEqual(result.type_signature, expected_type) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_INT) - placement = placements.CLIENTS - - with self.assertRaises(context_base.ContextError): - intrinsics.federated_value(value, placement) - - -class SequenceMapTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'sequence_unplaced', - _create_fake_fn(_INT, _FLOAT), - _create_fake_value(_SEQUENCE_INT), - _SEQUENCE_FLOAT, - ), - ( - 'sequence_clients', - _create_fake_fn(_INT, _FLOAT), - _create_fake_value(_SEQUENCE_INT_CLIENTS), - _SEQUENCE_FLOAT_CLIENTS, - ), - ( - 'sequence_server', - _create_fake_fn(_INT, _FLOAT), - _create_fake_value(_SEQUENCE_INT_SERVER), - _SEQUENCE_FLOAT_SERVER, - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, fn, arg, expected_type): - result = intrinsics.sequence_map(fn, arg) - - self.assertEqual(result.type_signature, expected_type) - - def test_raises_context_error_with_no_federated_context(self): - fn = _create_fake_fn(_INT, _FLOAT) - arg = _create_fake_value(_SEQUENCE_INT) - - with self.assertRaises(context_base.ContextError): - intrinsics.sequence_map(fn, arg) - - -class SequenceReduceTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'sequence_unplaced', - _create_fake_value(_SEQUENCE_INT), - _create_fake_value(_FLOAT), - _create_fake_fn([_FLOAT, _INT], _FLOAT), - ), - ( - 'sequence_clients', - _create_fake_value(_SEQUENCE_INT_CLIENTS), - _create_fake_value(_FLOAT_CLIENTS), - _create_fake_fn([_FLOAT, _INT], _FLOAT), - ), - ( - 'sequence_server', - _create_fake_value(_SEQUENCE_INT_SERVER), - _create_fake_value(_FLOAT_SERVER), - _create_fake_fn([_FLOAT, _INT], _FLOAT), - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, zero, op): - result = intrinsics.sequence_reduce(value, zero, op) - - expected_type = zero.type_signature - self.assertEqual(result.type_signature, expected_type) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_SEQUENCE_INT) - zero = _create_fake_value(_FLOAT) - op = _create_fake_fn([_FLOAT, _INT], _FLOAT) - - with self.assertRaises(context_base.ContextError): - intrinsics.sequence_reduce(value, zero, op) - - -class SequenceSumTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'sequence_unplaced', - _create_fake_value(_SEQUENCE_INT), - _INT, - ), - ( - 'sequence_clients', - _create_fake_value(_SEQUENCE_INT_CLIENTS), - _INT_CLIENTS, - ), - ( - 'sequence_server', - _create_fake_value(_SEQUENCE_INT_SERVER), - _INT_SERVER, - ), - ) - @context_stack_test_utils.with_context(_create_context) - def test_returns_result(self, value, expected_type): - result = intrinsics.sequence_sum(value) - - self.assertEqual(result.type_signature, expected_type) - - def test_raises_context_error_with_no_federated_context(self): - value = _create_fake_value(_SEQUENCE_INT) - - with self.assertRaises(context_base.ContextError): - intrinsics.sequence_sum(value) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/federated_context/value_impl.py b/tensorflow_federated/python/core/impl/federated_context/value_impl.py deleted file mode 100644 index f2e30e389c..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/value_impl.py +++ /dev/null @@ -1,393 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Representation of values inside a federated computation.""" - -import abc -from collections.abc import Hashable, Mapping -import itertools -import typing -from typing import Optional, Union - -import attrs - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import array -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.computation import function_utils -from tensorflow_federated.python.core.impl.computation import polymorphic_computation -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.context_stack import symbol_binding_context -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import typed_object - - -def _unfederated(type_signature): - if isinstance(type_signature, computation_types.FederatedType): - return type_signature.member - return type_signature - - -def _is_federated_struct(type_spec: computation_types.Type) -> bool: - return isinstance(type_spec, computation_types.FederatedType) and isinstance( - type_spec.member, computation_types.StructType - ) - - -def _check_struct_or_federated_struct( - vimpl: 'Value', - attribute: str, -): - if not isinstance( - vimpl.type_signature, computation_types.StructType - ) and not _is_federated_struct(vimpl.type_signature): - raise AttributeError( - f'`tff.Value` of non-structural type {vimpl.type_signature} has no ' - f'attribute {attribute}' - ) - - -def _bind_computation_to_reference(comp, op: str): - context = context_stack_impl.context_stack.current - if not isinstance(context, symbol_binding_context.SymbolBindingContext): - raise context_base.ContextError( - '`tff.Value`s should only be used in contexts which can bind ' - 'references, generally a `FederatedComputationContext`. Attempted ' - f'to bind the result of {op} in a context {context} of ' - f'type {type(context)}.' - ) - return context.bind_computation_to_reference(comp) - - -class Value(typed_object.TypedObject, metaclass=abc.ABCMeta): - """A generic base class for values that appear in TFF computations. - - If the value in this class is of `StructType` or `FederatedType` containing a - `StructType`, the inner fields can be accessed by name - (e.g. `y = my_value_impl.y`). - """ - - def __init__( - self, - comp: building_blocks.ComputationBuildingBlock, - ): - """Constructs a value of the given type. - - Args: - comp: An instance of building_blocks.ComputationBuildingBlock that - contains the logic that computes this value. - """ - super() - py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock) - self._comp = comp - - @property - def type_signature(self): - return self._comp.type_signature - - @property - def comp(self) -> building_blocks.ComputationBuildingBlock: - return self._comp - - def __repr__(self): - return repr(self._comp) - - def __str__(self): - return str(self._comp) - - def __dir__(self): - attributes = ['type_signature', 'comp'] - type_signature = _unfederated(self.type_signature) - if isinstance(type_signature, computation_types.StructType): - attributes.extend(dir(type_signature)) - return attributes - - def __getattr__(self, name): - py_typecheck.check_type(name, str) - _check_struct_or_federated_struct(self, name) - if _is_federated_struct(self.type_signature): - if name not in structure.name_list(self.type_signature.member): # pytype: disable=attribute-error - raise AttributeError( - "There is no such attribute '{}' in this federated tuple. Valid " - 'attributes: ({})'.format( - name, ', '.join(dir(self.type_signature.member)) # pytype: disable=attribute-error - ) - ) - - return Value( - building_block_factory.create_federated_getattr_call(self._comp, name) - ) - if name not in dir(self.type_signature): - attributes = ', '.join(dir(self.type_signature)) - raise AttributeError( - f"There is no such attribute '{name}' in this tuple. Valid " - f'attributes: ({attributes})' - ) - if isinstance(self._comp, building_blocks.Struct): - return Value(getattr(self._comp, name)) - return Value(building_blocks.Selection(self._comp, name=name)) - - def __bool__(self): - raise TypeError( - 'Federated computation values do not support boolean operations. ' - 'If you were attempting to perform logic on tensors, consider moving ' - 'this logic into a `tff.tensorflow.computation`.' - ) - - def __len__(self): - type_signature = _unfederated(self.type_signature) - if not isinstance(type_signature, computation_types.StructType): - raise TypeError( - 'Operator len() is only supported for (possibly federated) structure ' - 'types, but the object on which it has been invoked is of type {}.' - .format(self.type_signature) - ) - return len(type_signature) - - def __getitem__(self, key: Union[int, str, slice]): - py_typecheck.check_type(key, (int, str, slice)) - if isinstance(key, str): - return getattr(self, key) - if _is_federated_struct(self.type_signature): - return Value( - building_block_factory.create_federated_getitem_call(self._comp, key), - ) - if not isinstance(self.type_signature, computation_types.StructType): - raise TypeError( - 'Operator getitem() is only supported for structure types, but the ' - 'object on which it has been invoked is of type {}.'.format( - self.type_signature - ) - ) - elem_length = len(self.type_signature) - if isinstance(key, int): - if key < 0 or key >= elem_length: - raise IndexError( - 'The index of the selected element {} is out of range.'.format(key) - ) - if isinstance(self._comp, building_blocks.Struct): - return Value(self._comp[key]) - else: - return Value(building_blocks.Selection(self._comp, index=key)) - elif isinstance(key, slice): - index_range = range(*key.indices(elem_length)) - if not index_range: - raise IndexError( - 'Attempted to slice 0 elements, which is not currently supported.' - ) - return to_value([self[k] for k in index_range], None) - - def __iter__(self): - type_signature = _unfederated(self.type_signature) - if not isinstance(type_signature, computation_types.StructType): - raise TypeError( - 'Operator iter() is only supported for (possibly federated) ' - 'structure types, but the object on which it has been invoked is of ' - f'type {self.type_signature}.' - ) - for index in range(len(type_signature)): - yield self[index] - - def __call__(self, *args, **kwargs): - if not isinstance(self.type_signature, computation_types.FunctionType): - raise SyntaxError( - 'Function-like invocation is only supported for values of functional ' - 'types, but the value being invoked is of type {} that does not ' - 'support invocation.'.format(self.type_signature) - ) - if args or kwargs: - args = [to_value(x, None) for x in args] - kwargs = {k: to_value(v, None) for k, v in kwargs.items()} - arg = function_utils.pack_args( - self.type_signature.parameter, # pytype: disable=attribute-error - args, - kwargs, - ) - arg = to_value(arg, None).comp - else: - arg = None - call = building_blocks.Call(self._comp, arg) - ref = _bind_computation_to_reference(call, 'calling a `tff.Value`') - return Value(ref) - - -def _dictlike_items_to_value(items, type_spec, container_type) -> Value: - elements = [] - for i, (k, v) in enumerate(items): - element_type = None if type_spec is None else type_spec[i] # pytype: disable=unsupported-operands - element_value = to_value(v, element_type) - elements.append((k, element_value.comp)) - return Value(building_blocks.Struct(elements, container_type)) - - -def to_value( - arg: object, - type_spec: Optional[computation_types.Type], - *, - parameter_type_hint=None, - zip_if_needed: bool = False, -) -> Value: - """Converts the argument into an instance of the abstract class `tff.Value`. - - Instances of `tff.Value` represent TFF values that appear internally in - federated computations. This helper function can be used to wrap a variety of - Python objects as `tff.Value` instances to allow them to be passed as - arguments, used as functions, or otherwise manipulated within bodies of - federated computations. - - At the moment, the supported types include: - - * Simple constants of `str`, `int`, `float`, and `bool` types, mapped to - values of a TFF tensor type. - - * Numpy arrays (`np.ndarray` objects), also mapped to TFF tensors. - - * Dictionaries (`collections.OrderedDict` and unordered `dict`), `list`s, - `tuple`s, `namedtuple`s, and `Struct`s, all of which are mapped to - TFF tuple type. - - * Computations (constructed with either the `tff.tensorflow.computation` or - with the `tff.federated_computation` decorator), typically mapped to TFF - functions. - - * Placement literals (`tff.CLIENTS`, `tff.SERVER`), mapped to values of the - TFF placement type. - - This function is also invoked when attempting to execute a TFF computation. - All arguments supplied in the invocation are converted into TFF values prior - to execution. The types of Python objects that can be passed as arguments to - computations thus matches the types listed here. - - Args: - arg: An instance of one of the Python types that are convertible to TFF - values (instances of `tff.Value`). - type_spec: An optional type specifier that allows for disambiguating the - target type (e.g., when two TFF types can be mapped to the same Python - representations). If not specified, TFF tried to determine the type of the - TFF value automatically. - parameter_type_hint: An optional `tff.Type` or value convertible to it by - `tff.to_type()` which specifies an argument type to use in the case that - `arg` is a `polymorphic_computation.PolymorphicComputation`. - zip_if_needed: If `True`, attempt to coerce the result of `to_value` to - match `type_spec` by applying `intrinsics.federated_zip` to appropriate - elements. - - Returns: - An instance of `tff.Value` as described above. - - Raises: - TypeError: if `arg` is of an unsupported type, or of a type that does not - match `type_spec`. Raises explicit error message if TensorFlow constructs - are encountered, as TensorFlow code should be sealed away from TFF - federated context. - """ - # TODO: b/224484886 - Downcasting to all handled types. - arg = typing.cast( - Union[ - None, - Value, - building_blocks.ComputationBuildingBlock, - placements.PlacementLiteral, - computation_impl.ConcreteComputation, - polymorphic_computation.PolymorphicComputation, - computation_types.SequenceType, - structure.Struct, - py_typecheck.SupportsNamedTuple, - Mapping[Hashable, object], - tuple[object, ...], - list[object], - array.Array, - ], - arg, - ) - if isinstance(arg, Value): - result = arg - elif isinstance(arg, building_blocks.ComputationBuildingBlock): - result = Value(arg) - elif isinstance(arg, placements.PlacementLiteral): - result = Value(building_blocks.Placement(arg)) - elif isinstance( - arg, - ( - computation_impl.ConcreteComputation, - polymorphic_computation.PolymorphicComputation, - ), - ): - if isinstance(arg, polymorphic_computation.PolymorphicComputation): - if parameter_type_hint is None: - raise TypeError( - 'Polymorphic computations cannot be converted to `tff.Value`s ' - 'without a type hint. Consider explicitly specifying the ' - 'argument types of a computation before passing it to a ' - 'function that requires a `tff.Value` (such as a TFF intrinsic ' - 'like `federated_map`). If you are a TFF developer and think ' - 'this should be supported, consider providing ' - '`parameter_type_hint` as an argument to the encompassing ' - '`to_value` conversion.' - ) - parameter_type_hint = computation_types.to_type(parameter_type_hint) - arg = arg.fn_for_argument_type(parameter_type_hint) - py_typecheck.check_type(arg, computation_impl.ConcreteComputation) - result = Value(arg.to_compiled_building_block()) - elif isinstance(arg, structure.Struct): - items = structure.iter_elements(arg) - result = _dictlike_items_to_value(items, type_spec, None) - elif isinstance(arg, py_typecheck.SupportsNamedTuple): - items = arg._asdict().items() - result = _dictlike_items_to_value(items, type_spec, type(arg)) - elif attrs.has(type(arg)): - items = attrs.asdict(arg, recurse=False).items() - result = _dictlike_items_to_value(items, type_spec, type(arg)) - elif isinstance(arg, Mapping): - result = _dictlike_items_to_value(arg.items(), type_spec, type(arg)) - elif isinstance(arg, (tuple, list)) and not isinstance( - type_spec, computation_types.SequenceType - ): - items = zip(itertools.repeat(None), arg) - result = _dictlike_items_to_value(items, type_spec, type(arg)) - elif isinstance(arg, typing.get_args(array.Array)): - if type_spec is None: - type_spec = type_conversions.infer_type(arg) - if not isinstance(type_spec, computation_types.TensorType): - raise ValueError(f'Expected a `tff.TensorType`, found {type_spec}.') - literal = building_blocks.Literal(arg, type_spec) - result = Value(literal) - else: - raise TypeError( - 'Expected a Python types that is convertible to a `tff.Value`, found' - f' {type(arg)}. If this is backend-specific constructs, it was' - ' encountered in a federated context and TFF does not support mixing' - ' backend-specific and federated logic. Please wrap any ' - ' backend-specific constructs in a computation function.' - ) - py_typecheck.check_type(result, Value) - if type_spec is not None and not type_spec.is_assignable_from( - result.type_signature - ): - if zip_if_needed: - # Returns `None` if such a zip can't be performed. - zipped_comp = building_block_factory.zip_to_match_type( - comp_to_zip=result.comp, target_type=type_spec - ) - if zipped_comp is not None: - return Value(zipped_comp) - raise computation_types.TypeNotAssignableError( - type_spec, result.type_signature - ) - return result diff --git a/tensorflow_federated/python/core/impl/federated_context/value_impl_test.py b/tensorflow_federated/python/core/impl/federated_context/value_impl_test.py deleted file mode 100644 index 430ca31dcb..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/value_impl_test.py +++ /dev/null @@ -1,629 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections - -from absl.testing import absltest -from absl.testing import parameterized -import attrs -import numpy as np - -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -@attrs.define -class TestAttrClass: - x: object - y: object - - -class ValueTest(parameterized.TestCase): - - def run(self, result=None): - fc_context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - with context_stack_impl.context_stack.install(fc_context): - super().run(result) - - def bound_symbols(self): - return context_stack_impl.context_stack.current.symbol_bindings - - def test_raises_on_boolean_ops(self): - x_comp = building_blocks.Reference('foo', np.bool_) - x = value_impl.Value(x_comp) - with self.assertRaises(TypeError): - assert x - - def test_value_impl_with_reference(self): - x_comp = building_blocks.Reference('foo', np.int32) - x = value_impl.Value(x_comp) - self.assertIs(x.comp, x_comp) - self.assertEqual(str(x.type_signature), 'int32') - self.assertEqual(repr(x), "Reference('foo', TensorType(np.int32))") - self.assertEqual(str(x), 'foo') - with self.assertRaises(SyntaxError): - x(10) - - def test_value_impl_with_selection(self): - x = value_impl.Value( - building_blocks.Reference('foo', [('bar', np.int32), ('baz', np.bool_)]) - ) - self.assertContainsSubset(['bar', 'baz'], dir(x)) - self.assertLen(x, 2) - y = x.bar - self.assertIsInstance(y, value_impl.Value) - self.assertEqual(str(y.type_signature), 'int32') - self.assertEqual(str(y), 'foo.bar') - z = x['baz'] - self.assertEqual(str(z.type_signature), 'bool') - self.assertEqual(str(z), 'foo.baz') - with self.assertRaises(AttributeError): - _ = x.bak - x0 = x[0] - self.assertIsInstance(x0, value_impl.Value) - self.assertEqual(str(x0.type_signature), 'int32') - self.assertEqual(str(x0), 'foo[0]') - x1 = x[1] - self.assertEqual(str(x1.type_signature), 'bool') - self.assertEqual(str(x1), 'foo[1]') - with self.assertRaises(IndexError): - _ = x[2] - with self.assertRaises(IndexError): - _ = x[-1] - self.assertEqual(','.join(str(e) for e in iter(x)), 'foo[0],foo[1]') - self.assertEqual( - ','.join(str(e.type_signature) for e in iter(x)), 'int32,bool' - ) - with self.assertRaises(SyntaxError): - x(10) - - def test_value_impl_with_tuple(self): - x_comp = building_blocks.Reference('foo', np.int32) - y_comp = building_blocks.Reference('bar', np.bool_) - z = value_impl.Value(building_blocks.Struct([x_comp, ('y', y_comp)])) - self.assertIsInstance(z, value_impl.Value) - self.assertEqual(str(z.type_signature), '') - self.assertEqual(str(z), '') - self.assertContainsSubset(['y'], dir(z)) - self.assertEqual(str(z.y), 'bar') - self.assertIs(z.y.comp, y_comp) - self.assertLen(z, 2) - self.assertEqual(str(z[0]), 'foo') - self.assertIs(z[0].comp, x_comp) - self.assertEqual(str(z['y']), 'bar') - self.assertIs(z['y'].comp, y_comp) - self.assertEqual(','.join(str(e) for e in iter(z)), 'foo,bar') - with self.assertRaises(SyntaxError): - z(10) - - def test_value_impl_with_call(self): - x = value_impl.Value( - building_blocks.Reference( - 'foo', computation_types.FunctionType(np.int32, np.bool_) - ), - ) - y = value_impl.Value(building_blocks.Reference('bar', np.int32)) - z = x(y) - self.assertIsInstance(z, value_impl.Value) - self.assertEqual(str(z.type_signature), 'bool') - self.assertEqual(str(z), 'fc_FEDERATED_symbol_0') - bound_symbols = self.bound_symbols() - self.assertLen(bound_symbols, 1) - self.assertEqual(bound_symbols[0][0], str(z)) - self.assertEqual(str(bound_symbols[0][1]), 'foo(bar)') - with self.assertRaises(TypeError): - x() - w = value_impl.Value(building_blocks.Reference('bak', np.float32)) - with self.assertRaises(TypeError): - x(w) - - def test_value_impl_with_lambda(self): - arg_name = 'arg' - arg_type = [ - ('f', computation_types.FunctionType(np.int32, np.int32)), - ('x', np.int32), - ] - result_value = (lambda arg: arg.f(arg.f(arg.x)))( - value_impl.Value(building_blocks.Reference(arg_name, arg_type)) - ) - self.assertIsInstance(result_value, value_impl.Value) - self.assertEqual(str(result_value.type_signature), 'int32') - self.assertEqual(str(result_value), 'fc_FEDERATED_symbol_1') - bound_symbols = self.bound_symbols() - self.assertLen(bound_symbols, 2) - self.assertEqual(bound_symbols[1][0], 'fc_FEDERATED_symbol_1') - self.assertEqual(str(bound_symbols[1][1]), 'arg.f(fc_FEDERATED_symbol_0)') - self.assertEqual(bound_symbols[0][0], 'fc_FEDERATED_symbol_0') - self.assertEqual(str(bound_symbols[0][1]), 'arg.f(arg.x)') - - def test_to_value_for_tuple(self): - x = value_impl.Value( - building_blocks.Reference('foo', np.int32), - ) - y = value_impl.Value( - building_blocks.Reference('bar', np.bool_), - ) - v = value_impl.to_value((x, y), type_spec=None) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), '') - - def test_to_value_for_attrs_class(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.int32)) - v = value_impl.to_value(TestAttrClass(x, y), type_spec=None) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), '') - - def test_to_value_for_nested_attrs_class(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.int32)) - v = value_impl.to_value( - TestAttrClass(TestAttrClass(x, y), TestAttrClass(x, y)), type_spec=None - ) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), ',y=>') - - def test_to_value_for_list(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.bool_)) - v = value_impl.to_value([x, y], type_spec=None) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), '') - - def test_to_value_for_ordered_dict(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.bool_)) - v = value_impl.to_value( - collections.OrderedDict([('b', y), ('a', x)]), type_spec=None - ) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), '') - - def test_to_value_for_named_tuple(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.bool_)) - v = value_impl.to_value( - collections.namedtuple('_', 'a b')(x, y), type_spec=None - ) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), '') - - def test_to_value_for_structure(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.bool_)) - v = value_impl.to_value( - structure.Struct([('a', x), ('b', y)]), type_spec=None - ) - self.assertIsInstance(v, value_impl.Value) - self.assertEqual(str(v), '') - - def test_to_value_for_placements(self): - clients = value_impl.to_value(placements.CLIENTS, type_spec=None) - self.assertIsInstance(clients, value_impl.Value) - self.assertEqual(str(clients.type_signature), 'placement') - self.assertEqual(str(clients), 'CLIENTS') - - def test_to_value_for_computations(self): - type_spec = computation_types.TensorType(np.int32) - computation_proto = computation_factory.create_lambda_identity(type_spec) - computation = computation_impl.ConcreteComputation( - computation_proto=computation_proto, - context_stack=context_stack_impl.context_stack, - ) - - value = value_impl.to_value(computation, type_spec=None) - - self.assertIsInstance(value, value_impl.Value) - expected_type = computation_types.FunctionType(type_spec, type_spec) - self.assertEqual(value.type_signature, expected_type) - - def test_to_value_with_string(self): - value = value_impl.to_value('a', computation_types.TensorType(np.str_)) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'str') - - def test_to_value_with_int(self): - value = value_impl.to_value(1, computation_types.TensorType(np.int32)) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'int32') - - def test_to_value_with_float(self): - value = value_impl.to_value(1.0, computation_types.TensorType(np.float32)) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'float32') - - def test_to_value_with_bool(self): - value = value_impl.to_value(True, computation_types.TensorType(np.bool_)) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'bool') - - def test_to_value_with_np_int32(self): - value = value_impl.to_value( - np.int32(1), computation_types.TensorType(np.int32) - ) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'int32') - - def test_to_value_with_np_int64(self): - value = value_impl.to_value( - np.int64(1), computation_types.TensorType(np.int64) - ) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'int64') - - def test_to_value_with_np_float32(self): - value = value_impl.to_value( - np.float32(1.0), computation_types.TensorType(np.float32) - ) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'float32') - - def test_to_value_with_np_float64(self): - value = value_impl.to_value( - np.float64(1.0), computation_types.TensorType(np.float64) - ) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'float64') - - def test_to_value_with_np_bool(self): - value = value_impl.to_value( - np.bool_(1.0), computation_types.TensorType(np.bool_) - ) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'bool') - - def test_to_value_with_np_ndarray(self): - value = value_impl.to_value( - np.ndarray(shape=(2, 0), dtype=np.int32), - computation_types.TensorType(np.int32, [2, 0]), - ) - self.assertIsInstance(value, value_impl.Value) - self.assertEqual(str(value.type_signature), 'int32[2,0]') - - def test_to_value_with_list_of_ints(self): - with self.assertRaises(TypeError): - value_impl.to_value([1, 2, 3], computation_types.SequenceType(np.int32)) - - def test_to_value_raises_value_error(self): - with self.assertRaises(ValueError): - value_impl.to_value(10, computation_types.TensorType(np.bool_)) - - def test_tf_mapping_raises_helpful_error(self): - with self.assertRaisesRegex( - TypeError, - 'Expected a Python types that is convertible to a `tff.Value`', - ): - value_impl.to_value(object(), None) - - def test_slicing_support_namedtuple(self): - x = value_impl.Value(building_blocks.Reference('foo', np.int32)) - y = value_impl.Value(building_blocks.Reference('bar', np.bool_)) - v = value_impl.to_value(collections.namedtuple('_', 'a b')(x, y), None) - sliced_v = v[: int(len(v) / 2)] - self.assertIsInstance(sliced_v, value_impl.Value) - sliced_v = v[:4:2] - self.assertEqual(str(sliced_v), '') - self.assertIsInstance(sliced_v, value_impl.Value) - sliced_v = v[4::-1] - self.assertEqual(str(sliced_v), '') - self.assertIsInstance(sliced_v, value_impl.Value) - with self.assertRaisesRegex(IndexError, 'slice 0 elements'): - _ = v[2:4] - - def test_slicing_fails_non_namedtuple(self): - v = value_impl.to_value(np.ones([10, 10, 10], dtype=np.float32), None) - with self.assertRaisesRegex(TypeError, 'only supported for structure'): - _ = v[:1] - - def test_slicing_support_non_tuple_underlying_comp(self): - test_computation_building_blocks = building_blocks.Reference( - 'test', [np.int32] * 5 - ) - v = value_impl.Value(test_computation_building_blocks) - sliced_v = v[:4:2] - self.assertIsInstance(sliced_v, value_impl.Value) - sliced_v = v[4:2:-1] - self.assertIsInstance(sliced_v, value_impl.Value) - with self.assertRaisesRegex(IndexError, 'slice 0 elements'): - _ = v[2:4:-1] - - @parameterized.named_parameters(('list', list), ('tuple', tuple)) - def test_slicing_tuple_values_from_front(self, sequence_type): - def _to_value(cbb): - return value_impl.to_value(cbb, None) - - t = sequence_type(range(0, 50, 10)) - v = _to_value(t) - - self.assertEqual((str(v.type_signature)), '') - self.assertEqual(str(v[:]), str(v)) - - sliced = v[:2] - self.assertEqual((str(sliced.type_signature)), '') - self.assertEqual(str(sliced), '<0,10>') - - expected_symbol_bindings = [ - ('fc_FEDERATED_symbol_0', [r'( -> 0)()']), - ('fc_FEDERATED_symbol_1', [r'( -> 10)()']), - ('fc_FEDERATED_symbol_2', [r'( -> 20)()']), - ('fc_FEDERATED_symbol_3', [r'( -> 30)()']), - ('fc_FEDERATED_symbol_4', [r'( -> 40)()']), - ] - - bindings = self.bound_symbols() - for (bound_name, comp), (expected_name, expected_regex) in zip( - bindings, expected_symbol_bindings - ): - self.assertEqual(bound_name, expected_name) - self.assertRegexMatch(comp.compact_representation(), expected_regex) - - @parameterized.named_parameters(('list', list), ('tuple', tuple)) - def test_slicing_tuple_values_from_back(self, sequence_type): - def _to_value(cbb): - return value_impl.to_value(cbb, None) - - t = sequence_type(range(0, 50, 10)) - v = _to_value(t) - - self.assertEqual((str(v.type_signature)), '') - self.assertEqual(str(v[:]), str(v)) - - sliced = v[-3:] - self.assertEqual((str(sliced.type_signature)), '') - self.assertEqual(str(sliced), '<20,30,40>') - - expected_symbol_bindings = [ - ('fc_FEDERATED_symbol_0', [r'( -> 0)()']), - ('fc_FEDERATED_symbol_1', [r'( -> 10)()']), - ('fc_FEDERATED_symbol_2', [r'( -> 20)()']), - ('fc_FEDERATED_symbol_3', [r'( -> 30)()']), - ('fc_FEDERATED_symbol_4', [r'( -> 40)()']), - ] - - bindings = self.bound_symbols() - for (bound_name, comp), (expected_name, expected_regex) in zip( - bindings, expected_symbol_bindings - ): - self.assertEqual(bound_name, expected_name) - self.assertRegexMatch(comp.compact_representation(), expected_regex) - - @parameterized.named_parameters(('list', list), ('tuple', tuple)) - def test_slicing_tuple_values_skipping_steps(self, sequence_type): - def _to_value(val): - return value_impl.to_value(val, None) - - t = sequence_type(range(0, 50, 10)) - v = _to_value(t) - - sliced = v[::2] - self.assertEqual((str(sliced.type_signature)), '') - self.assertEqual(str(sliced), '<0,20,40>') - - expected_symbol_bindings = [ - ('fc_FEDERATED_symbol_0', [r'( -> 0)()']), - ('fc_FEDERATED_symbol_1', [r'( -> 10)()']), - ('fc_FEDERATED_symbol_2', [r'( -> 20)()']), - ('fc_FEDERATED_symbol_3', [r'( -> 30)()']), - ('fc_FEDERATED_symbol_4', [r'( -> 40)()']), - ] - - bindings = self.bound_symbols() - for (bound_name, comp), (expected_name, expected_regex) in zip( - bindings, expected_symbol_bindings - ): - self.assertEqual(bound_name, expected_name) - self.assertRegexMatch(comp.compact_representation(), expected_regex) - - def test_getitem_resolution_federated_value_clients(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [np.int32, np.bool_], placements.CLIENTS, False - ), - ), - None, - ) - self.assertEqual( - str(federated_value.type_signature), '{}@CLIENTS' - ) - federated_attribute = federated_value[0] - self.assertEqual(str(federated_attribute.type_signature), '{int32}@CLIENTS') - - def test_getitem_federated_slice_constructs_comp_clients(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [np.int32, np.bool_], placements.CLIENTS, False - ), - ), - None, - ) - self.assertEqual( - str(federated_value.type_signature), '{}@CLIENTS' - ) - identity = federated_value[:] - self.assertEqual(str(identity.type_signature), '{}@CLIENTS') - self.assertEqual(str(identity), 'federated_map(<(x -> ),test>)') - - def test_getitem_resolution_federated_value_server(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [np.int32, np.bool_], placements.SERVER, True - ), - ), - None, - ) - self.assertEqual(str(federated_value.type_signature), '@SERVER') - federated_attribute = federated_value[0] - self.assertEqual(str(federated_attribute.type_signature), 'int32@SERVER') - - def test_getitem_federated_slice_constructs_comp_server(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [np.int32, np.bool_], placements.SERVER, True - ), - ), - None, - ) - self.assertEqual(str(federated_value.type_signature), '@SERVER') - identity = federated_value[:] - self.assertEqual(str(identity.type_signature), '@SERVER') - self.assertEqual( - str(identity), 'federated_apply(<(x -> ),test>)' - ) - - def test_getitem_key_resolution(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placements.SERVER, True - ), - ), - None, - ) - self.assertEqual( - str(federated_value.type_signature), '@SERVER' - ) - federated_attribute = federated_value['a'] - self.assertEqual(str(federated_attribute.type_signature), 'int32@SERVER') - with self.assertRaises(AttributeError): - _ = federated_value['badkey'] - - def test_getattr_resolution_federated_value_server(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placements.SERVER, True - ), - ), - None, - ) - self.assertEqual( - str(federated_value.type_signature), '@SERVER' - ) - federated_attribute = federated_value.a - self.assertEqual(str(federated_attribute.type_signature), 'int32@SERVER') - - def test_getattr_resolution_federated_value_clients(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placements.CLIENTS, False - ), - ), - None, - ) - self.assertEqual( - str(federated_value.type_signature), '{}@CLIENTS' - ) - federated_attribute = federated_value.a - self.assertEqual(str(federated_attribute.type_signature), '{int32}@CLIENTS') - - def test_getattr_raises_federated_value_unknown_attr(self): - federated_value_clients = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placements.CLIENTS, True - ), - ), - None, - ) - self.assertEqual( - str(federated_value_clients.type_signature), '@CLIENTS' - ) - with self.assertRaisesRegex( - AttributeError, r'There is no such attribute \'c\'' - ): - _ = federated_value_clients.c - federated_value_server = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placements.SERVER, True - ), - ), - None, - ) - self.assertEqual( - str(federated_value_server.type_signature), '@SERVER' - ) - with self.assertRaisesRegex( - AttributeError, r'There is no such attribute \'c\'' - ): - _ = federated_value_server.c - - def test_getattr_federated_value_with_none_default_missing_name(self): - federated_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.FederatedType( - [('a', np.int32), ('b', np.bool_)], placements.SERVER, True - ), - ), - None, - ) - self.assertEqual( - str(federated_value.type_signature), '@SERVER' - ) - missing_attr = getattr(federated_value, 'c', None) - self.assertIsNone(missing_attr) - - def test_getattr_non_federated_value_with_none_default_missing_name(self): - struct_value = value_impl.to_value( - building_blocks.Reference( - 'test', - computation_types.StructType([('a', np.int32), ('b', np.bool_)]), - ), - None, - ) - self.assertEqual(str(struct_value.type_signature), '') - missing_attr = getattr(struct_value, 'c', None) - self.assertIsNone(missing_attr) - - def test_value_impl_dir(self): - x_comp = building_blocks.Reference('foo', np.int32) - x = value_impl.Value(x_comp) - - result = dir(x) - self.assertIsInstance(result, list) - self.assertNotEmpty(result) - self.assertIn('type_signature', result) - - def test_value_impl_help(self): - x_comp = building_blocks.Reference('foo', np.int32) - x = value_impl.Value(x_comp) - help(x) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/federated_context/value_utils.py b/tensorflow_federated/python/core/impl/federated_context/value_utils.py deleted file mode 100644 index dffb23bd61..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/value_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities file for functions with TFF `Value`s as inputs and outputs.""" - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_block_factory -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.types import computation_types - - -def get_curried(fn): - """Returns a curried version of function `fn` that takes a parameter tuple. - - For functions `fn` of types -> U, the result is a function - of the form T1 -> (T2 -> (T3 -> .... (Tn -> U) ... )). - - Note: No attempt is made at avoiding naming conflicts in cases where `fn` - contains references. The arguments of the curriend function are named `argN` - with `N` starting at 0. - - Args: - fn: A value of a functional TFF type. - - Returns: - A value that represents the curried form of `fn`. - """ - py_typecheck.check_type(fn, value_impl.Value) - if not isinstance(fn.type_signature, computation_types.FunctionType): - raise ValueError( - f'Expected a `tff.FunctionType`, found {fn.type_signature}.' - ) - if not isinstance(fn.type_signature.parameter, computation_types.StructType): - raise ValueError( - f'Expected a `tff.StructType`, found {fn.type_signature.parameter}.' - ) - param_elements = fn.type_signature.parameter.items() - references = [] - for idx, (_, elem_type) in enumerate(param_elements): - references.append(building_blocks.Reference('arg{}'.format(idx), elem_type)) - result = building_blocks.Call(fn.comp, building_blocks.Struct(references)) - for ref in references[::-1]: - result = building_blocks.Lambda(ref.name, ref.type_signature, result) - return value_impl.Value(result) - - -def ensure_federated_value(value, placement=None, label=None): - """Ensures `value` is a federated value placed at `placement`. - - If `value` is not a `computation_types.FederatedType` but is a - `computation_types.StructType` that can be converted via `federated_zip` - to a `computation_types.FederatedType`, inserts the call to `federated_zip` - and returns the result. If `value` cannot be converted, raises a TypeError. - - Args: - value: A `value_impl.Value` to check and convert to a federated value if - possible. - placement: The expected placement. If None, any placement is allowed. - label: An optional string label that describes `value`. - - Returns: - The value as a federated value, automatically zipping if necessary. - - Raises: - TypeError: if `value` is not a `FederatedType` and cannot be converted to - a `FederatedType` with `federated_zip`. - """ - py_typecheck.check_type(value, value_impl.Value) - if label is not None: - py_typecheck.check_type(label, str) - - if not isinstance(value.type_signature, computation_types.FederatedType): - comp = value.comp - try: - zipped = building_block_factory.create_federated_zip(comp) - except (TypeError, ValueError) as e: - raise TypeError( - 'The {l} must be a FederatedType or implicitly convertible ' - 'to a FederatedType (got a {t}).'.format( - l=label if label else 'value', t=comp.type_signature - ) - ) from e - value = value_impl.Value(zipped) - - if placement is not None and value.type_signature.placement is not placement: # pytype: disable=attribute-error - raise TypeError( - 'The {} should be placed at {}, but it is placed at {}.'.format( - label if label else 'value', - placement, - value.type_signature.placement, # pytype: disable=attribute-error - ) - ) - - return value diff --git a/tensorflow_federated/python/core/impl/federated_context/value_utils_test.py b/tensorflow_federated/python/core/impl/federated_context/value_utils_test.py deleted file mode 100644 index b92af6bb72..0000000000 --- a/tensorflow_federated/python/core/impl/federated_context/value_utils_test.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import computation_factory -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation_context -from tensorflow_federated.python.core.impl.federated_context import value_impl -from tensorflow_federated.python.core.impl.federated_context import value_utils -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -class ValueUtilsTest(parameterized.TestCase): - - def run(self, result=None): - fc_context = federated_computation_context.FederatedComputationContext( - context_stack_impl.context_stack - ) - with context_stack_impl.context_stack.install(fc_context): - super().run(result) - - def test_get_curried(self): - type_sec = computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ]) - computation_proto = computation_factory.create_lambda_identity(type_sec) - type_signature = computation_types.FunctionType(type_sec, type_sec) - building_block = building_blocks.CompiledComputation( - proto=computation_proto, name='test', type_signature=type_signature - ) - value = value_impl.Value(building_block) - - curried = value_utils.get_curried(value) - - self.assertEqual( - curried.type_signature.compact_representation(), - '(int32 -> (int32 -> ))', - ) - self.assertEqual( - curried.comp.compact_representation(), - '(arg0 -> (arg1 -> comp#test()))', - ) - - def test_ensure_federated_value(self): - - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) - ) - def _(x): - x = value_impl.to_value(x, type_spec=None) - value_utils.ensure_federated_value(x, placements.CLIENTS) - return x - - def test_ensure_federated_value_wrong_placement(self): - - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) - ) - def _(x): - x = value_impl.to_value(x, type_spec=None) - with self.assertRaises(TypeError): - value_utils.ensure_federated_value(x, placements.SERVER) - return x - - def test_ensure_federated_value_implicitly_zippable(self): - - @federated_computation.federated_computation( - computation_types.StructType(( - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.CLIENTS), - )) - ) - def _(x): - x = value_impl.to_value(x, type_spec=None) - value_utils.ensure_federated_value(x) - return x - - def test_ensure_federated_value_fails_on_unzippable(self): - - @federated_computation.federated_computation( - computation_types.StructType(( - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.SERVER), - )) - ) - def _(x): - x = value_impl.to_value(x, type_spec=None) - with self.assertRaises(TypeError): - value_utils.ensure_federated_value(x) - return x - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/BUILD b/tensorflow_federated/python/core/impl/types/BUILD index edd80f4208..542d565782 100644 --- a/tensorflow_federated/python/core/impl/types/BUILD +++ b/tensorflow_federated/python/core/impl/types/BUILD @@ -1,4 +1,4 @@ -load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = ["//:package_license"], @@ -6,11 +6,9 @@ package( ":types_packages", "//tensorflow_federated/python/core/impl:impl_users", "//tensorflow_federated/python/core/impl/compiler:compiler_packages", - "//tensorflow_federated/python/core/impl/computation:computation_packages", "//tensorflow_federated/python/core/impl/execution_contexts:execution_contexts_packages", "//tensorflow_federated/python/core/impl/executor_stacks:executor_stacks_packages", "//tensorflow_federated/python/core/impl/executors:executors_packages", - "//tensorflow_federated/python/core/impl/federated_context:federated_context_packages", ], ) @@ -25,217 +23,5 @@ py_library( name = "types", srcs = ["__init__.py"], visibility = ["//tensorflow_federated:__pkg__"], - deps = [ - ":array_shape", - ":computation_types", - ":type_analysis", - ":type_conversions", - ":type_serialization", - ], -) - -py_library( - name = "array_shape", - srcs = ["array_shape.py"], - deps = [ - "//tensorflow_federated/proto/v0:array_py_pb2", - "//tensorflow_federated/proto/v0:data_type_py_pb2", - ], -) - -py_test( - name = "array_shape_test", - srcs = ["array_shape_test.py"], - deps = [ - ":array_shape", - "//tensorflow_federated/proto/v0:array_py_pb2", - ], -) - -py_library( - name = "computation_types", - srcs = ["computation_types.py"], - deps = [ - ":array_shape", - ":dtype_utils", - ":placements", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_test( - name = "computation_types_test", - size = "small", - srcs = ["computation_types_test.py"], - args = [ - "--golden", - "$(location computation_types_test_goldens/container_types_full_repr.expected)", - "--golden", - "$(location computation_types_test_goldens/long_formatted_with_diff.expected)", - "--golden", - "$(location computation_types_test_goldens/short_compact_repr.expected)", - ], - data = [ - "computation_types_test_goldens/container_types_full_repr.expected", - "computation_types_test_goldens/long_formatted_with_diff.expected", - "computation_types_test_goldens/short_compact_repr.expected", - ], - deps = [ - ":computation_types", - ":placements", - "//tensorflow_federated/python/common_libs:golden", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_library( - name = "dtype_utils", - srcs = ["dtype_utils.py"], - deps = ["//tensorflow_federated/proto/v0:data_type_py_pb2"], -) - -py_test( - name = "dtype_utils_test", - srcs = ["dtype_utils_test.py"], - deps = [":dtype_utils"], -) - -py_library( - name = "placements", - srcs = ["placements.py"], -) - -py_test( - name = "placements_test", - size = "small", - srcs = ["placements_test.py"], - deps = [":placements"], -) - -py_library( - name = "type_analysis", - srcs = ["type_analysis.py"], - deps = [ - ":array_shape", - ":computation_types", - ":placements", - ":type_conversions", - ":type_transformations", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_test( - name = "type_analysis_test", - size = "small", - srcs = ["type_analysis_test.py"], - deps = [ - ":computation_types", - ":placements", - ":type_analysis", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_library( - name = "type_conversions", - srcs = ["type_conversions.py"], - deps = [ - ":computation_types", - ":dtype_utils", - ":typed_object", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_test( - name = "type_conversions_test", - size = "small", - srcs = ["type_conversions_test.py"], - deps = [ - ":computation_types", - ":placements", - ":type_conversions", - ":typed_object", - "//tensorflow_federated/python/common_libs:structure", - ], -) - -py_library( - name = "type_factory", - srcs = ["type_factory.py"], - deps = [":computation_types"], -) - -py_test( - name = "type_factory_test", - size = "small", - srcs = ["type_factory_test.py"], - deps = [ - ":computation_types", - ":type_factory", - ], -) - -py_library( - name = "type_serialization", - srcs = ["type_serialization.py"], - deps = [ - ":array_shape", - ":computation_types", - ":dtype_utils", - ":placements", - "//tensorflow_federated/proto/v0:array_py_pb2", - "//tensorflow_federated/proto/v0:computation_py_pb2", - ], -) - -py_test( - name = "type_serialization_test", - size = "small", - srcs = ["type_serialization_test.py"], - deps = [ - ":array_shape", - ":computation_types", - ":dtype_utils", - ":placements", - ":type_serialization", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/proto/v0:data_type_py_pb2", - ], -) - -py_library( - name = "type_test_utils", - srcs = ["type_test_utils.py"], - deps = [":computation_types"], -) - -py_library( - name = "type_transformations", - srcs = ["type_transformations.py"], - deps = [ - ":computation_types", - "//tensorflow_federated/python/common_libs:py_typecheck", - ], -) - -py_test( - name = "type_transformations_test", - size = "small", - srcs = ["type_transformations_test.py"], - deps = [ - ":computation_types", - ":placements", - ":type_transformations", - ], -) - -py_library( - name = "typed_object", - srcs = ["typed_object.py"], - deps = [":computation_types"], + deps = ["@federated_language//federated_language"], ) diff --git a/tensorflow_federated/python/core/impl/types/__init__.py b/tensorflow_federated/python/core/impl/types/__init__.py index 418bc15420..30e80e2444 100644 --- a/tensorflow_federated/python/core/impl/types/__init__.py +++ b/tensorflow_federated/python/core/impl/types/__init__.py @@ -12,35 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. """Libraries for interacting with the type of a computation.""" +import federated_language # pylint: disable=g-importing-member -from tensorflow_federated.python.core.impl.types.array_shape import ArrayShape -from tensorflow_federated.python.core.impl.types.array_shape import is_shape_fully_defined -from tensorflow_federated.python.core.impl.types.array_shape import num_elements_in_shape -from tensorflow_federated.python.core.impl.types.computation_types import AbstractType -from tensorflow_federated.python.core.impl.types.computation_types import FederatedType -from tensorflow_federated.python.core.impl.types.computation_types import FunctionType -from tensorflow_federated.python.core.impl.types.computation_types import PlacementType -from tensorflow_federated.python.core.impl.types.computation_types import SequenceType -from tensorflow_federated.python.core.impl.types.computation_types import StructType -from tensorflow_federated.python.core.impl.types.computation_types import StructWithPythonType -from tensorflow_federated.python.core.impl.types.computation_types import TensorType -from tensorflow_federated.python.core.impl.types.computation_types import to_type -from tensorflow_federated.python.core.impl.types.computation_types import Type -from tensorflow_federated.python.core.impl.types.computation_types import type_mismatch_error_message -from tensorflow_federated.python.core.impl.types.computation_types import TypeNotAssignableError -from tensorflow_federated.python.core.impl.types.computation_types import TypeRelation -from tensorflow_federated.python.core.impl.types.computation_types import TypesNotEquivalentError -from tensorflow_federated.python.core.impl.types.computation_types import TypesNotIdenticalError -from tensorflow_federated.python.core.impl.types.computation_types import UnexpectedTypeError -from tensorflow_federated.python.core.impl.types.type_analysis import contains -from tensorflow_federated.python.core.impl.types.type_analysis import contains_only -from tensorflow_federated.python.core.impl.types.type_analysis import count -from tensorflow_federated.python.core.impl.types.type_analysis import is_structure_of_floats -from tensorflow_federated.python.core.impl.types.type_analysis import is_structure_of_integers -from tensorflow_federated.python.core.impl.types.type_analysis import is_structure_of_tensors -from tensorflow_federated.python.core.impl.types.type_analysis import is_tensorflow_compatible_type -from tensorflow_federated.python.core.impl.types.type_conversions import type_to_py_container -from tensorflow_federated.python.core.impl.types.type_serialization import deserialize_type -from tensorflow_federated.python.core.impl.types.type_serialization import serialize_type +ArrayShape = federated_language.ArrayShape +is_shape_fully_defined = federated_language.array_shape_is_fully_defined +num_elements_in_shape = federated_language.num_elements_in_array_shape +AbstractType = federated_language.AbstractType +FederatedType = federated_language.FederatedType +FunctionType = federated_language.FunctionType +PlacementType = federated_language.PlacementType +SequenceType = federated_language.SequenceType +StructType = federated_language.StructType +StructWithPythonType = federated_language.StructWithPythonType +TensorType = federated_language.TensorType +to_type = federated_language.to_type +Type = federated_language.Type +type_mismatch_error_message = ( + federated_language.framework.type_mismatch_error_message +) +TypeNotAssignableError = federated_language.framework.TypeNotAssignableError +TypeRelation = federated_language.framework.TypeRelation +TypesNotEquivalentError = federated_language.framework.TypesNotEquivalentError +TypesNotIdenticalError = federated_language.framework.TypesNotIdenticalError +UnexpectedTypeError = federated_language.framework.UnexpectedTypeError +contains = federated_language.framework.type_contains +contains_only = federated_language.framework.type_contains_only +count = federated_language.framework.type_count +is_structure_of_floats = federated_language.framework.is_structure_of_floats +is_structure_of_integers = federated_language.framework.is_structure_of_integers +is_structure_of_tensors = federated_language.framework.is_structure_of_tensors +is_tensorflow_compatible_type = ( + federated_language.framework.is_tensorflow_compatible_type +) +type_to_py_container = federated_language.framework.type_to_py_container +deserialize_type = federated_language.framework.deserialize_type +serialize_type = federated_language.framework.serialize_type # pylint: disable=g-importing-member diff --git a/tensorflow_federated/python/core/impl/types/array_shape.py b/tensorflow_federated/python/core/impl/types/array_shape.py deleted file mode 100644 index 00c57ee0a3..0000000000 --- a/tensorflow_federated/python/core/impl/types/array_shape.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2023, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for working with shapes. - -The shape of an `Array` may be one of the following: - -* Fully-defined: Has a known number of dimensions and a known size for each - dimension (e.g. (2, 3)). -* Partially-defined: Has a known number of dimensions, and an unknown size for - one or more dimension (e.g. (2, None)). -* Unknown: Has an unknown number of dimensions (e.g. None). -* Scalar: Has no dimensions (e.g. ()). -""" - -from collections.abc import Sequence -import functools -import operator -from typing import Optional, Union - -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.proto.v0 import data_type_pb2 # pylint: disable=unused-import # b/330931277 - - -_EmptyTuple = tuple[()] -_ArrayShapeLike = Union[Sequence[Optional[int]], None, _EmptyTuple] - -# ArrayShape is the Python representation of the `ArrayShape` protobuf, and is -# the shape of an `Array`. -ArrayShape = Union[tuple[Optional[int], ...], None, _EmptyTuple] - - -def from_proto(shape_pb: array_pb2.ArrayShape) -> ArrayShape: - """Returns a `tff.types.ArrayShape` for the `shape_pb`.""" - if shape_pb.unknown_rank: - return None - else: - return tuple(d if d >= 0 else None for d in shape_pb.dim) - - -def to_proto(shape: ArrayShape) -> array_pb2.ArrayShape: - """Returns an `ArrayShape` for the `shape`.""" - if shape is not None: - dims = [d if d is not None else -1 for d in shape] - return array_pb2.ArrayShape(dim=dims) - else: - return array_pb2.ArrayShape(unknown_rank=True) - - -def is_shape_fully_defined(shape: ArrayShape) -> bool: - """Returns `True` if `shape` is fully defined, False otherwise. - - Args: - shape: A `tff.types.ArrayShape`. - """ - return shape is not None and all(dim is not None for dim in shape) - - -def is_shape_scalar(shape: ArrayShape) -> bool: - """Returns `True` if `shape` is scalar, False otherwise. - - Args: - shape: A `tff.types.ArrayShape`. - """ - return shape is not None and not shape - - -def is_compatible_with(target: ArrayShape, other: ArrayShape) -> bool: - """Returns `True` if `target` is compatible with `other`, otherwise `False`. - - Two shapes are compatible if there exists a fully-defined shape that both - shapes can represent. For example: - - * `None` is compatible with all shapes. - * `(None, None)` is compatible with all two-dimensional shapes, and also - `None`. - * `(2, None)` is compatible with all two-dimensional shapes with size 2 in the - 0th dimension, and also `(None, None)` and `None`. - * `(2, 3) is compatible with itself, and also `(32, None)`, `(None, 3]), - `(None, None)`, and `None`. - - Args: - target: A `tff.types.ArrayShape`. - other: Another `tff.types.ArrayShape`. - """ - if target is None or other is None: - return True - - if len(target) != len(other): - return False - - return all(x is None or y is None or x == y for x, y in zip(target, other)) - - -def num_elements_in_shape(shape: ArrayShape) -> Optional[int]: - """Returns the number of elements in `shape`, or `None` if not fully defined. - - Args: - shape: A `tff.types.ArrayShape`. - """ - if is_shape_fully_defined(shape): - return functools.reduce(operator.mul, shape, 1) - else: - return None diff --git a/tensorflow_federated/python/core/impl/types/array_shape_test.py b/tensorflow_federated/python/core/impl/types/array_shape_test.py deleted file mode 100644 index 6a05a941b7..0000000000 --- a/tensorflow_federated/python/core/impl/types/array_shape_test.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2023, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized - -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.python.core.impl.types import array_shape - - -class ArrayShapeTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('fully_defined', array_pb2.ArrayShape(dim=[2, 3]), (2, 3)), - ('partially_defined', array_pb2.ArrayShape(dim=[2, -1]), (2, None)), - ('unknown', array_pb2.ArrayShape(unknown_rank=True), None), - ('scalar_empty', array_pb2.ArrayShape(dim=[]), ()), - ('scalar_none', array_pb2.ArrayShape(), ()), - ) - def test_from_proto_returns_value(self, proto, expected_value): - actual_value = array_shape.from_proto(proto) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ('fully_defined', (2, 3), array_pb2.ArrayShape(dim=[2, 3])), - ('partially_defined', (2, None), array_pb2.ArrayShape(dim=[2, -1])), - ('unknown', None, array_pb2.ArrayShape(unknown_rank=True)), - ('scalar', (), array_pb2.ArrayShape(dim=[])), - ) - def test_to_proto_returns_value(self, shape, expected_value): - actual_value = array_shape.to_proto(shape) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ('fully_defined', (2, 3)), - ('scalar', ()), - ) - def test_is_shape_fully_defined_returns_true(self, shape): - result = array_shape.is_shape_fully_defined(shape) - self.assertTrue(result) - - @parameterized.named_parameters( - ('partially_defined', (2, None)), - ('unknown', None), - ) - def test_is_shape_fully_defined_returns_false(self, shape): - result = array_shape.is_shape_fully_defined(shape) - self.assertFalse(result) - - def test_is_shape_scalar_returns_true(self): - shape = () - result = array_shape.is_shape_scalar(shape) - self.assertTrue(result) - - @parameterized.named_parameters( - ('fully_defined', (2, 3)), - ('partially_defined', (2, None)), - ('unknown', None), - ) - def test_is_shape_scalar_returns_false(self, shape): - result = array_shape.is_shape_scalar(shape) - self.assertFalse(result) - - @parameterized.named_parameters( - ('fully_defined_and_fully_defined', (2, 3), (2, 3)), - ('fully_defined_and_partially_defined', (2, 3), (2, None)), - ('fully_defined_and_partially_defined_only_rank', (2, 3), (None, None)), - ('fully_defined_and_unknown', (2, 3), None), - ('partially_defined_and_fully_defined', (2, None), (2, 3)), - ('partially_defined_and_partially_defined', (2, None), (2, None)), - ( - 'partially_defined_and_partially_defined_only_rank', - (2, None), - (None, None), - ), - ('partially_defined_and_unknown', (2, None), None), - ('partially_defined_only_rank_and_fully_defined', (None, None), (2, 3)), - ( - 'partially_defined_only_rank_and_partially_defined', - (None, None), - (2, None), - ), - ( - 'partially_defined_only_rank_and_partially_defined_only_rank', - (None, None), - (None, None), - ), - ('partially_defined_only_rank_and_unknown', (None, None), None), - ('unknown_and_fully_defined', None, (2, 3)), - ('unknown_and_partially_defined', None, (2, None)), - ('unknown_and_partially_defined_only_rank', None, (None, None)), - ('unknown_and_unknown', None, None), - ('unknown_and_scalar', None, ()), - ('scalar_and_unknown', (), None), - ('scalar_and_scalar', (), ()), - ) - def test_is_compatible_with_returns_true(self, target, other): - result = array_shape.is_compatible_with(target, other) - self.assertTrue(result) - - @parameterized.named_parameters( - ('fully_defined_and_scalar', (2, 3), ()), - ('fully_defined_wrong_size', (2, 3), (20, 3)), - ('partially_defined_and_scalar', (2, None), ()), - ('partially_defined_wrong_size', (2, None), (20, None)), - ('partially_defined_only_rank_and_scalar', (None, None), ()), - ('scalar_and_fully_defined', (), (2, 3)), - ('scalar_and_partially_defined', (), (2, None)), - ('scalar_and_partially_defined_only_rank', (), (None, None)), - ) - def test_is_compatible_with_returns_false(self, target, other): - result = array_shape.is_compatible_with(target, other) - self.assertFalse(result) - - @parameterized.named_parameters( - ('fully_defined', (2, 3), 6), - ('partially_defined', (2, None), None), - ('unknown', None, None), - ('scalar', (), 1), - ) - def test_num_elements_in_shape_returns_result(self, shape, expected_result): - actual_result = array_shape.num_elements_in_shape(shape) - self.assertEqual(actual_result, expected_result) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/computation_types.py b/tensorflow_federated/python/core/impl/types/computation_types.py deleted file mode 100644 index 20048692bf..0000000000 --- a/tensorflow_federated/python/core/impl/types/computation_types.py +++ /dev/null @@ -1,1275 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines functions and classes for building and manipulating TFF types.""" - -import abc -import atexit -import collections -from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence -import difflib -import enum -from typing import Optional, TypeVar, Union -import weakref - -import attrs -import numpy as np -from typing_extensions import TypeGuard - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import dtype_utils -from tensorflow_federated.python.core.impl.types import placements - -T = TypeVar('T') - - -class UnexpectedTypeError(TypeError): - - def __init__(self, expected: type['Type'], actual: 'Type'): - message = f'Expected type of kind {expected}, found type {actual}' - super().__init__(message) - self.actual = actual - self.expected = expected - - -# Prevent wrapping on a 100-character terminal. -MAX_LINE_LEN = 100 - - -@enum.unique -class TypeRelation(enum.Enum): - EQUIVALENT = 'equivalent' - IDENTICAL = 'identical' - ASSIGNABLE = 'assignable' - - -def type_mismatch_error_message( - first: 'Type', - second: 'Type', - relation: TypeRelation, - second_is_expected: bool = False, -) -> str: - """Returns an error message describing the mismatch between two types.""" - maybe_expected = 'expected ' if second_is_expected else '' - first_str = first.compact_representation() - second_str = second.compact_representation() - diff = None - if first_str == second_str: - # The two only differ in container types or some other property not - # visible via the compact representation, so show `repr` instead. - # No diff is used because `repr` prints to a single line. - first_str = repr(first) - second_str = repr(second) - diff = None - elif len(first_str) > MAX_LINE_LEN or len(second_str) > MAX_LINE_LEN: - # The types are large structures, and so the formatted representation is - # used and a summary diff is added. The logic here is that large types - # may be easier to diff visually with a more structured representation, - # and logical line breaks are required to make diff output useful. - first_str = first.formatted_representation() - second_str = second.formatted_representation() - split_first = first_str.split('\n') - split_second = second_str.split('\n') - diff = '\n'.join(difflib.unified_diff(split_first, split_second)) - message = [ - 'Type', - f'`{first_str}`', - f'is not {relation.value} to {maybe_expected}type', - f'`{second_str}`', - ] - if diff: - message += [f'\nDiff:\n{diff}'] - single_line = ' '.join(message) - if len(single_line) > MAX_LINE_LEN or '\n' in single_line: - return '\n'.join(message) - else: - return single_line - - -class TypeNotAssignableError(TypeError): - - def __init__(self, source_type, target_type): - self.message = type_mismatch_error_message( - source_type, target_type, TypeRelation.ASSIGNABLE - ) - super().__init__(self.message) - self.source_type = source_type - self.target_type = target_type - - -class TypesNotEquivalentError(TypeError): - - def __init__(self, first_type, second_type): - self.message = type_mismatch_error_message( - first_type, second_type, TypeRelation.EQUIVALENT - ) - super().__init__(self.message) - self.first_type = first_type - self.second_type = second_type - - -class TypesNotIdenticalError(TypeError): - - def __init__(self, first_type, second_type): - self.message = type_mismatch_error_message( - first_type, second_type, TypeRelation.IDENTICAL - ) - super().__init__(self.message) - self.first_type = first_type - self.second_type = second_type - - -class Type(metaclass=abc.ABCMeta): - """An abstract interface for all classes that represent TFF types.""" - - def compact_representation(self) -> str: - """Returns the compact string representation of this type.""" - return _string_representation(self, formatted=False) - - def formatted_representation(self) -> str: - """Returns the formatted string representation of this type.""" - return _string_representation(self, formatted=True) - - @abc.abstractmethod - def children(self) -> Iterator['Type']: - """Returns a generator yielding immediate child types.""" - raise NotImplementedError - - @abc.abstractmethod - def __repr__(self): - """Returns a full-form representation of this type.""" - raise NotImplementedError - - def __str__(self): - """Returns a concise representation of this type.""" - return self.compact_representation() - - @abc.abstractmethod - def __hash__(self): - """Produces a hash value for this type.""" - raise NotImplementedError - - @abc.abstractmethod - def __eq__(self, other): - """Determines whether two type definitions are identical. - - Note that this notion of equality is stronger than equivalence. Two types - with equivalent definitions may not be identical, e.g., if they represent - templates with differently named type variables in their definitions. - - Args: - other: The other type to compare against. - - Returns: - `True` if type definitions are syntactically identical (as defined above), - otherwise `False`. - - Raises: - NotImplementedError: If not implemented in the derived class. - """ - raise NotImplementedError - - def __ne__(self, other): - return not self == other - - def check_assignable_from(self, source_type: 'Type') -> None: - """Raises if values of `source_type` cannot be cast to this type.""" - if not self.is_assignable_from(source_type): - raise TypeNotAssignableError(source_type=source_type, target_type=self) - - @abc.abstractmethod - def is_assignable_from(self, source_type: 'Type') -> bool: - """Returns whether values of `source_type` can be cast to this type.""" - raise NotImplementedError - - def check_equivalent_to(self, other: 'Type') -> None: - """Raises if values of 'other' cannot be cast to and from this type.""" - if not self.is_equivalent_to(other): - raise TypesNotEquivalentError(self, other) - - def is_equivalent_to(self, other: 'Type') -> bool: - """Returns whether values of `other` can be cast to and from this type.""" - return self.is_assignable_from(other) and other.is_assignable_from(self) - - def check_identical_to(self, other: 'Type') -> None: - """Raises if `other` and `Type` are not exactly identical.""" - if not self.is_identical_to(other): - raise TypesNotIdenticalError(self, other) - - def is_identical_to(self, other: 'Type') -> bool: - """Returns whether or not `self` and `other` are exactly identical.""" - return self == other - - -class _Intern(abc.ABCMeta): - """A metaclass which interns instances. - - This is used to create classes where the following predicate holds: - `MyClass(some_args) is MyClass(some_args)` - - That is, objects of the class with the same constructor parameters result - in values with the same object identity. This can make comparison of deep - structures much cheaper, since a shallow equality check can short-circuit - comparison. - - Classes which set `_Intern` as a metaclass must have a - `_hashable_from_init_args` classmethod which defines exactly the parameters - passed to the `__init__` method. If one of the parameters passed to the - `_Intern.__call__` is an iterator it will be converted to a list before - `_hashable_from_init_args` and `__init__` are called. - - Note: also that this metaclass must only be used with *immutable* values, as - mutation would cause all similarly-constructed instances to be mutated - together. - - Inherits from `abc.ABCMeta` to prevent subclass conflicts. - """ - - @classmethod - def _hashable_from_init_args(mcs, *args, **kwargs) -> Hashable: - raise NotImplementedError - - def __call__(cls, *args, **kwargs): - - # Convert all `Iterator`s in both `args` and `kwargs` to `list`s so they can - # be used in both `_hashable_from_init_args` and `__init__`. - def _normalize(obj): - if isinstance(obj, Iterator): - return list(obj) - else: - return obj - - args = [_normalize(x) for x in args] - kwargs = {k: _normalize(v) for k, v in kwargs.items()} - - # Salt the key with `cls` to account for two different classes that return - # the same result from `_hashable_from_init_args`. - key = (cls, cls._hashable_from_init_args(*args, **kwargs)) - intern_pool = _intern_pool[cls] - instance = intern_pool.get(key, None) - if instance is None: - instance = super().__call__(*args, **kwargs) - intern_pool[key] = instance - return instance - - -# A per-`typing.Type` map from `__init__` arguments to object instances. -# -# This is used by the `_Intern` metaclass to allow reuse of object instances -# when new objects are requested with the same `__init__` arguments as -# existing object instances. -# -# Implementation note: this double-map is used rather than a single map -# stored as a field of each class because some class objects themselves would -# begin destruction before the map fields of other classes, causing errors -# during destruction. -_intern_pool: MutableMapping[type[Type], MutableMapping[Hashable, Type]] = ( - collections.defaultdict(dict) -) - - -def _clear_intern_pool() -> None: - # We must clear our `WeakKeyValueDictionary`s at the end of the program to - # prevent Python from deleting the standard library out from under us before - # removing the entries from the dictionary. Yes, this is cursed. - # - # If this isn't done, Python will call `__eq__` on our types after - # `abc.ABCMeta` has already been deleted from the world, resulting in - # exceptions after main. - global _intern_pool - _intern_pool = None - - -atexit.register(_clear_intern_pool) - - -_DtypeLike = Union[type[np.generic], np.dtype] - - -def _is_dtype_like(obj: object) -> TypeGuard[_DtypeLike]: - """Returns `True` if `obj` is dtype like, otherwise `False`.""" - if isinstance(obj, type) and issubclass(obj, np.generic): - return True - else: - return isinstance(obj, np.dtype) - - -def _is_array_shape_like( - obj: object, -) -> TypeGuard[Union[array_shape._ArrayShapeLike]]: - """Returns `True` if `obj` is an `_ArrayShapeLike`, otherwise `False`.""" - if obj is None: - return True - elif isinstance(obj, Sequence): - # If iterating over the `Sequence` fails, then `obj` is not an - # `array_shape._ArrayShapeLike`. - try: - return all(isinstance(x, int) or x is None for x in obj) - except Exception: # pylint: disable=broad-exception-caught - return False - else: - return False - - -def _to_dtype(dtype: _DtypeLike) -> np.dtype: - """Returns a `np.dtype` for the dtype like object. - - Normalize `dtype` to an instance of `np.dtype` that describes an array - scalar. see https://numpy.org/doc/stable/reference/arrays.scalars.html. - - Args: - dtype: A dtype like object. - """ - if isinstance(dtype, np.dtype): - dtype = dtype.type - if dtype is np.bytes_: - dtype = np.str_ - - if not dtype_utils.is_valid_dtype(dtype): - raise NotImplementedError(f'Unexpected `dtype` found: {dtype}.') - return np.dtype(dtype) - - -class TensorType(Type, metaclass=_Intern): - """An implementation of `tff.Type` representing types of tensors in TFF.""" - - @classmethod - def _hashable_from_init_args( - cls, - dtype: _DtypeLike, - shape: array_shape._ArrayShapeLike = (), - ) -> Hashable: - """Returns hashable `TensorType.__init__` args.""" - dtype = _to_dtype(dtype) - if shape is not None: - shape = tuple(shape) - return (dtype, shape) - - def __init__( - self, - dtype: _DtypeLike, - shape: array_shape._ArrayShapeLike = (), - ): - """Constructs a new instance from the given `dtype` and `shape`. - - Args: - dtype: The `np.dtype` of the array. - shape: The shape of the array. - - Raises: - TypeError: if arguments are of the wrong types. - """ - self._dtype = _to_dtype(dtype) - if shape is not None: - shape = tuple(shape) - self._shape = shape - - def children(self) -> Iterator[Type]: - return iter(()) - - @property - def dtype(self) -> np.dtype: - return self._dtype - - @property - def shape(self) -> array_shape.ArrayShape: - return self._shape - - def __repr__(self): - dtype_repr = f'np.{self._dtype.type.__name__}' - if array_shape.is_shape_scalar(self._shape): - return f'TensorType({dtype_repr})' - else: - return f'TensorType({dtype_repr}, {self._shape!r})' - - def __hash__(self): - return hash((self._dtype, self._shape)) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, TensorType) - and self._dtype == other.dtype - and self._shape == other.shape - ) - - def is_assignable_from(self, source_type: Type) -> bool: - if self is source_type: - return True - - if ( - not isinstance(source_type, TensorType) - or self.dtype != source_type.dtype - ): - return False - - target_shape = self.shape - source_shape = source_type.shape - if target_shape is None: - return True - elif source_shape is None: - return False - - if len(target_shape) != len(source_shape): - return False - - def _dimension_is_assignable_from(target_dim, source_dim): - return (target_dim is None) or (target_dim == source_dim) - - return all( - _dimension_is_assignable_from(x, y) - for x, y in zip(target_shape, source_shape) - ) - - -def _format_struct_type_members(struct_type: 'StructType') -> str: - def _element_repr(element): - name, value = element - if name is not None: - return "('{}', {!r})".format(name, value) - return repr(value) - - return ', '.join(_element_repr(e) for e in struct_type.items()) - - -def _to_named_types( - elements: Iterable[object], -) -> Sequence[tuple[Optional[str], Type]]: - """Creates an `Iterable` of optionally named types from `elements`. - - This function creates an `Iterable` of optionally named types by iterating - over `elements` and normalizing each element. - - If `elements` is an `Iterable` with named elements (e.g. `Mapping` or - `NamedTuple`), the normalize element will have a name equal to the name of the - element and a value equal to the value of the element convereted to a type - using `to_type`. - - If `elements` is an `Iterable` with unnamed elements (e.g. list), the - normalized element will have a name of `None` and a value equal to the element - convereted to a type using `to_type`. - - Note: This function treats a single element being passed in as `elements` as - if it were an iterable of that element. - - Args: - elements: An iterable of named or unnamed objects to convert to `tff.Types`. - See `tff.types.to_type` for more information. - - Returns: - A `Sequence` where each each element is `tuple[Optional[str], Type]`. - """ - - if py_typecheck.is_name_value_pair(elements): - elements = [elements] - elif isinstance(elements, py_typecheck.SupportsNamedTuple): - elements = elements._asdict().items() - elif isinstance(elements, Mapping): - elements = elements.items() - - def _to_named_value_pair(element: object) -> tuple[Optional[str], Type]: - if py_typecheck.is_name_value_pair(element): - name, value = element - else: - name = None - value = element - value = to_type(value) - return (name, value) - - return [_to_named_value_pair(x) for x in elements] - - -def _reserved_names_in_elements( - elements: Sequence[tuple[Optional[str], object]], - reserved_names: Sequence[str], -) -> set[str]: - element_names = {n for n, _ in elements if n is not None} - return set(reserved_names).intersection(element_names) - - -class StructType(structure.Struct, Type, metaclass=_Intern): - """An implementation of `tff.Type` representing structural types in TFF. - - Elements initialized by name can be accessed as `foo.name`, and otherwise by - index, `foo[index]`. - - Elements can not be given names that would conflict with the methods and on - this class. - """ - - @classmethod - def _hashable_from_init_args( - cls, - elements: Iterable[object], - *, - convert: bool = True, - ) -> Hashable: - if convert: - elements = _to_named_types(elements) - invalid_names = _reserved_names_in_elements(elements, dir(cls)) - if invalid_names: - raise ValueError( - 'Expected named elements to not match any reserved names, found' - f' {invalid_names}.' - ) - return (tuple(elements), convert) - - def __init__( - self, - elements: Iterable[object], - *, - convert: bool = True, - ): - """Constructs a new instance from the given element types. - - Args: - elements: An iterable of element specifications. Each element - specification is either a type spec (an instance of `tff.Type` or - something convertible to it via `tff.types.to_type`) for the element, or - a (name, spec) for elements that have defined names. Alternatively, one - can supply here an instance of `collections.OrderedDict` mapping element - names to their types (or things that are convertible to types). - convert: A flag to determine if the elements should be converted using - `tff.types.to_type` or not. - """ - if convert: - elements = _to_named_types(elements) - structure.Struct.__init__(self, elements) - - def children(self) -> Iterator[Type]: - return (element for _, element in self.items()) - - @property - def python_container(self) -> Optional[type[object]]: - return None - - def items(self) -> Iterator[tuple[Optional[str], Type]]: - return structure.iter_elements(self) - - def __repr__(self): - members = _format_struct_type_members(self) - return f'StructType([{members}])' - - def __hash__(self): - # Salt to avoid overlap. - return hash((structure.Struct.__hash__(self), 'NTT')) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, StructType) and structure.Struct.__eq__(self, other) - ) - - def is_assignable_from(self, source_type: Type) -> bool: - if self is source_type: - return True - if not isinstance(source_type, StructType): - return False - target_elements = list(self.items()) - source_elements = list(source_type.items()) - if len(target_elements) != len(source_elements): - return False - for (target_name, target_element), (source_name, source_element) in zip( - target_elements, source_elements - ): - if source_name is not None and source_name != target_name: - return False - if not target_element.is_assignable_from(source_element): - return False - return True - - -class StructWithPythonType(StructType, metaclass=_Intern): - """A representation of a structure paired with a Python container type. - - Elements can not be given names that would conflict with the methods and on - this class. - """ - - @classmethod - def _hashable_from_init_args( - cls, elements: Iterable[object], container_type: type[object] - ) -> Hashable: - elements = _to_named_types(elements) - invalid_names = _reserved_names_in_elements(elements, dir(cls)) - if invalid_names: - raise ValueError( - 'Expected named elements to not match any reserved names, found' - f' {invalid_names}.' - ) - return (tuple(elements), container_type) - - def __init__(self, elements: Iterable[object], container_type: type[object]): - super().__init__(elements) - self._container_type = container_type - - @property - def python_container(self) -> type[object]: - return self._container_type - - def __repr__(self): - members = _format_struct_type_members(self) - return 'StructType([{}]) as {}'.format( - members, self._container_type.__name__ - ) - - def __hash__(self): - # Salt to avoid overlap. - return hash((structure.Struct.__hash__(self), 'NTTWPCT')) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, StructWithPythonType) - and (self._container_type == other._container_type) - and structure.Struct.__eq__(self, other) - ) - - -class SequenceType(Type, metaclass=_Intern): - """An implementation of `tff.Type` representing types of sequences in TFF. - - IMPORTANT: since `SequenceType` is frequently backed by `tf.data.Dataset` - which converts `list` to `tuple`, any `SequenceType` constructed with - `StructWithPythonType` elements will convert any `list` python container type - to `tuple` python container types for interoperability. - """ - - @classmethod - def _hashable_from_init_args(cls, element: object) -> Hashable: - element = to_type(element) - return (element,) - - def __init__(self, element: object): - """Constructs a new instance from the given `element` type. - - Args: - element: A specification of the element type, either an instance of - `tff.Type` or something convertible to it by `tff.types.to_type`. - """ - - def convert_struct_with_list_to_struct_with_tuple(type_spec: T) -> T: - """Convert any StructWithPythonType using lists to use tuples.""" - # We ignore non-struct, non-tensor types, these are not well formed types - # for sequence elements. - if not isinstance(type_spec, StructType): - return type_spec - elements = [ - (name, convert_struct_with_list_to_struct_with_tuple(value)) - for name, value in type_spec.items() - ] - if not isinstance(type_spec, StructWithPythonType): - return StructType(elements=elements) - container_cls = type_spec.python_container - return StructWithPythonType( - elements=elements, - container_type=tuple if container_cls is list else container_cls, - ) - - element = to_type(element) - self._element = convert_struct_with_list_to_struct_with_tuple(element) - - children_types = _get_contained_children_types(self) - if ( - children_types.federated - or children_types.function - or children_types.sequence - ): - raise ValueError( - 'Expected a `tff.SequenceType` to not contain `tff.FederatedType`s, ' - f'`tff.FunctionType`s, or `tff.SequenceType`s, found {self}.' - ) - - def children(self) -> Iterator[Type]: - yield self._element - - @property - def element(self) -> Type: - return self._element - - def __repr__(self): - return 'SequenceType({!r})'.format(self._element) - - def __hash__(self): - return hash(self._element) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, SequenceType) and self._element == other.element - ) - - def is_assignable_from(self, source_type: Type) -> bool: - if self is source_type: - return True - return isinstance( - source_type, SequenceType - ) and self.element.is_assignable_from(source_type.element) - - -class FunctionType(Type, metaclass=_Intern): - """An implementation of `tff.Type` representing functional types in TFF.""" - - @classmethod - def _hashable_from_init_args( - cls, parameter: Optional[object], result: object - ) -> Hashable: - if parameter is not None: - parameter = to_type(parameter) - result = to_type(result) - return (parameter, result) - - def __init__(self, parameter: Optional[object], result: object): - """Constructs a new instance from the given `parameter` and `result` types. - - Args: - parameter: A specification of the parameter type, either an instance of - `tff.Type` or something convertible to it by `tff.types.to_type`. - Multiple input arguments can be specified as a single `tff.StructType`. - result: A specification of the result type, either an instance of - `tff.Type` or something convertible to it by `tff.types.to_type`. - """ - if parameter is not None: - parameter = to_type(parameter) - self._parameter = parameter - self._result = to_type(result) - - def children(self) -> Iterator[Type]: - if self._parameter is not None: - yield self._parameter - yield self._result - - @property - def parameter(self) -> Optional[Type]: - return self._parameter - - @property - def result(self) -> Type: - return self._result - - def __repr__(self): - return 'FunctionType({!r}, {!r})'.format(self._parameter, self._result) - - def __hash__(self): - return hash((self._parameter, self._result)) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, FunctionType) - and self._parameter == other.parameter - and self._result == other.result - ) - - def is_assignable_from(self, source_type: Type) -> bool: - if self is source_type: - return True - if not isinstance(source_type, FunctionType): - return False - if (self.parameter is None) != (source_type.parameter is None): - return False - # Note that function parameters are contravariant, so we invert the check. - if ( - self.parameter is not None - and not source_type.parameter.is_assignable_from(self.parameter) - ): - return False - return self.result.is_assignable_from(source_type.result) - - -class AbstractType(Type, metaclass=_Intern): - """An implementation of `tff.Type` representing abstract types in TFF.""" - - @classmethod - def _hashable_from_init_args(cls, label: str) -> Hashable: - return (label,) - - def __init__(self, label: str): - """Constructs a new instance from the given string `label`. - - Args: - label: A string label of an abstract type. All occurrences of the label - within a computation's type signature refer to the same concrete type. - """ - self._label = label - - def children(self) -> Iterator[Type]: - return iter(()) - - @property - def label(self) -> str: - return self._label - - def __repr__(self): - return "AbstractType('{}')".format(self._label) - - def __hash__(self): - return hash(self._label) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, AbstractType) and self._label == other.label - ) - - def is_assignable_from(self, source_type: Type) -> bool: - del source_type # Unused. - # TODO: b/113112108 - Revise this to extend the relation of assignability to - # abstract types. - raise TypeError('Abstract types are not comparable.') - - -class PlacementType(Type, metaclass=_Intern): - """An implementation of `tff.Type` representing the placement type in TFF. - - There is only one placement type, a TFF built-in, just as there is only one - `int` or `str` type in Python. All instances of this class represent the same - built-in TFF placement type. - """ - - @classmethod - def _hashable_from_init_args(cls, *args, **kwargs) -> Hashable: - del args, kwargs # Unused. - return () - - def children(self) -> Iterator[Type]: - return iter(()) - - def __repr__(self): - return 'PlacementType()' - - def __hash__(self): - return 0 - - def __eq__(self, other): - return (self is other) or isinstance(other, PlacementType) - - def is_assignable_from(self, source_type: Type) -> bool: - if self is source_type: - return True - return isinstance(source_type, PlacementType) - - -class FederatedType(Type, metaclass=_Intern): - """An implementation of `tff.Type` representing federated types in TFF.""" - - @classmethod - def _hashable_from_init_args( - cls, - member: object, - placement: placements.PlacementLiteral, - all_equal: Optional[bool] = None, - ) -> Hashable: - member = to_type(member) - return (member, placement, all_equal) - - def __init__( - self, - member: object, - placement: placements.PlacementLiteral, - all_equal: Optional[bool] = None, - ): - """Constructs a new federated type instance. - - Args: - member: An instance of `tff.Type` or something convertible to it, that - represents the type of the member components of each value of this - federated type. - placement: The specification of placement that the member components of - this federated type are hosted on. Must be either a placement literal - such as `tff.SERVER` or `tff.CLIENTS` to refer to a globally defined - placement, or a placement label to refer to a placement defined in other - parts of a type signature. Specifying placement labels is not - implemented yet. - all_equal: A `bool` value that indicates whether all members of the - federated type are equal (`True`), or are allowed to differ (`False`). - If `all_equal` is `None`, the value is selected as the default for the - placement, e.g., `True` for `tff.SERVER` and `False` for `tff.CLIENTS`. - """ - self._member = to_type(member) - self._placement = placement - if all_equal is None: - all_equal = placement.default_all_equal - self._all_equal = all_equal - - children_types = _get_contained_children_types(self) - if children_types.federated or children_types.function: - raise ValueError( - 'Expected a `tff.FederatedType` to not contain `tff.FederatedType`s ' - f'or `tff.FunctionType`s, found {self}.' - ) - - # TODO: b/113112108 - Extend this to support federated types parameterized - # by abstract placement labels, such as those used in generic types of - # federated operators. - - def children(self) -> Iterator[Type]: - yield self._member - - @property - def member(self) -> Type: - return self._member - - @property - def placement(self) -> placements.PlacementLiteral: - return self._placement - - @property - def all_equal(self) -> bool: - return self._all_equal - - def __repr__(self): - return 'FederatedType({!r}, {!r}, {!r})'.format( - self._member, self._placement, self._all_equal - ) - - def __hash__(self): - return hash((self._member, self._placement, self._all_equal)) - - def __eq__(self, other): - return (self is other) or ( - isinstance(other, FederatedType) - and self._member == other.member - and self._placement == other.placement - and self._all_equal == other.all_equal - ) - - def is_assignable_from(self, source_type: Type) -> bool: - if self is source_type: - return True - return ( - isinstance(source_type, FederatedType) - and self.member.is_assignable_from(source_type.member) - and (not self.all_equal or source_type.all_equal) - and self.placement is source_type.placement - ) - - -def to_type(obj: object) -> Type: - """Converts the argument into an instance of `tff.Type`. - - Examples of arguments convertible to tensor types: - - ```python - np.int32 - (np.int32, [10]) - (np.int32, [None]) - ``` - - Examples of arguments convertible to flat named tuple types: - - ```python - [np.int32, np.bool] - (np.int32, np.bool) - [('a', np.int32), ('b', np.bool)] - ('a', np.int32) - collections.OrderedDict([('a', np.int32), ('b', np.bool)]) - ``` - - Examples of arguments convertible to nested named tuple types: - - ```python - (np.int32, (np.float32, np.bool)) - (np.int32, (('x', np.float32), np.bool)) - ((np.int32, [1]), (('x', (np.float32, [2])), (np.bool, [3]))) - ``` - - `attr.s` class instances can also be used to describe TFF types by populating - the fields with the corresponding types: - - ```python - @attr.s(auto_attribs=True) - class MyDataClass: - int_scalar - string_array - - obj = MyDataClass(...) - type_spec = tff.types.to_type(obj) - - @tff.tensorflow.computation(type_spec) - def work(my_data): - assert isinstance(my_data, MyDataClass) - ... - ``` - - Args: - obj: Either an instance of `tff.Type`, or an argument convertible to - `tff.Type`. - - Returns: - An instance of `tff.Type` corresponding to the given `obj`. - """ - # TODO: b/113112108 - Add multiple examples of valid type specs here in the - # comments, in addition to the unit test. - if isinstance(obj, Type): - return obj - elif _is_dtype_like(obj): - return TensorType(obj) # pytype: disable=wrong-arg-types # b/290661340 - elif ( - isinstance(obj, tuple) - and len(obj) == 2 - and _is_dtype_like(obj[0]) - and _is_array_shape_like(obj[1]) - ): - dtype, shape = obj - return TensorType(dtype, shape) - elif isinstance(obj, (list, tuple)): - if any(py_typecheck.is_name_value_pair(e, name_type=str) for e in obj): - # The sequence has a (name, value) elements, the whole sequence is most - # likely intended to be a `Struct`, do not store the Python container. - return StructType(obj) - else: - return StructWithPythonType(obj, type(obj)) - elif attrs.has(type(obj)): - return StructWithPythonType(attrs.asdict(obj, recurse=False), type(obj)) - elif isinstance(obj, Mapping): - return StructWithPythonType(obj, type(obj)) - elif isinstance(obj, structure.Struct): - return StructType(structure.to_elements(obj)) - else: - raise TypeError( - 'Unable to interpret an argument of type {} as a type spec.'.format( - py_typecheck.type_string(type(obj)) - ) - ) - - -@attrs.define(frozen=True) -class _ContainedChildrenTypes: - """The types of children `tff.Types` contained by a `tff.Type`. - - This data structure is used by `_get_contained_children_types` to package - the types of children `tff.Types` contained by a `tff.Type` in a more - convenient way. - """ - - tensor: bool = False - struct: bool = False - struct_with_python_type: bool = False - sequence: bool = False - function: bool = False - abstract: bool = False - placement: bool = False - federated: bool = False - - -# Manual cache used rather than `cachetools.cached` due to incompatibility -# with `WeakKeyDictionary`. We want to use a `WeakKeyDictionary` so that -# cache entries are destroyed once the types they index no longer exist. -_contained_children_types_cache: MutableMapping[ - Type, _ContainedChildrenTypes -] = weakref.WeakKeyDictionary({}) - - -def _clear_contained_children_types_cache(): - # We must clear our `WeakKeyValueDictionary`s at the end of the program to - # prevent Python from deleting the standard library out from under us before - # removing the entries from the dictionary. Yes, this is cursed. - # - # If this isn't done, Python will call `__eq__` on our types after - # `abc.ABCMeta` has already been deleted from the world, resulting in - # exceptions after main. - global _contained_children_types_cache - _contained_children_types_cache = None - - -atexit.register(_clear_contained_children_types_cache) - - -def _get_contained_children_types(type_spec: Type) -> _ContainedChildrenTypes: - """Returns the types of children `tff.Types` contained by `type_spec`. - - The `_ContainedChildrenTypes` is cached so that this function can be used in - performance sensitive operations. - - Args: - type_spec: A `tff.Type`. - - Raises: - RuntimeError: If the cache becomes corrupted in some unexpected way. - """ - if _contained_children_types_cache is None: - raise RuntimeError('Unexpected runtime error.') - children_types = _contained_children_types_cache.get(type_spec, None) - if children_types is not None: - return children_types - - children_types = _ContainedChildrenTypes() - for child_type in type_spec.children(): - # Create a mutable dict from the frozen `_ContainedChildrenTypes` instance; - # add the child and grandchildren updates; and then evolve the instance. - updates = attrs.asdict(children_types) - if isinstance(child_type, TensorType): - updates['tensor'] = True - elif isinstance(child_type, StructType): - updates['struct'] = True - elif isinstance(child_type, StructWithPythonType): - updates['struct_with_python_type'] = True - elif isinstance(child_type, SequenceType): - updates['sequence'] = True - elif isinstance(child_type, FunctionType): - updates['function'] = True - elif isinstance(child_type, AbstractType): - updates['abstract'] = True - elif isinstance(child_type, PlacementType): - updates['placement'] = True - elif isinstance(child_type, FederatedType): - updates['federated'] = True - else: - raise NotImplementedError(f'Unexpected type found: {type(child_type)}.') - grandchildren_types = _get_contained_children_types(child_type) - for key, value in attrs.asdict(grandchildren_types).items(): - if value: - updates[key] = True - children_types = attrs.evolve(children_types, **updates) - - _contained_children_types_cache[type_spec] = children_types - return children_types - - -def _string_representation(type_spec: Type, formatted: bool) -> str: - """Returns the string representation of a TFF `Type`. - - This function creates a `list` of strings representing the given `type_spec`; - combines the strings in either a formatted or un-formatted representation; and - returns the resulting string representation. - - Args: - type_spec: An instance of a TFF `Type`. - formatted: A boolean indicating if the returned string should be formatted. - - Raises: - TypeError: If `type_spec` has an unexpected type. - """ - - def _combine(components): - """Returns a `list` of strings by combining `components`. - - This function creates and returns a `list` of strings by combining a `list` - of `components`. Each `component` is a `list` of strings representing a part - of the string of a TFF `Type`. The `components` are combined by iteratively - **appending** the last element of the result with the first element of the - `component` and then **extending** the result with remaining elements of the - `component`. - - For example: - - >>> _combine([['a'], ['b'], ['c']]) - ['abc'] - - >>> _combine([['a', 'b', 'c'], ['d', 'e', 'f']]) - ['abcd', 'ef'] - - This function is used to help track where new-lines should be inserted into - the string representation if the lines are formatted. - - Args: - components: A `list` where each element is a `list` of strings - representing a part of the string of a TFF `Type`. - """ - lines = [''] - for component in components: - lines[-1] = '{}{}'.format(lines[-1], component[0]) - lines.extend(component[1:]) - return lines - - def _indent(lines, indent_chars=' '): - """Returns an indented `list` of strings.""" - return ['{}{}'.format(indent_chars, e) for e in lines] - - def _lines_for_named_types(named_type_specs, formatted): - """Returns a `list` of strings representing the given `named_type_specs`. - - Args: - named_type_specs: A `list` of named computations, each being a pair - consisting of a name (either a string, or `None`) and a - `ComputationBuildingBlock`. - formatted: A boolean indicating if the returned string should be - formatted. - """ - lines = [] - for index, (name, type_spec) in enumerate(named_type_specs): - if index != 0: - if formatted: - lines.append([',', '']) - else: - lines.append([',']) - element_lines = _lines_for_type(type_spec, formatted) - if name is not None: - element_lines = _combine([ - ['{}='.format(name)], - element_lines, - ]) - lines.append(element_lines) - return _combine(lines) - - def _lines_for_type(type_spec, formatted): - """Returns a `list` of strings representing the given `type_spec`. - - Args: - type_spec: An instance of a TFF `Type`. - formatted: A boolean indicating if the returned string should be - formatted. - """ - if isinstance(type_spec, AbstractType): - return [type_spec.label] - elif isinstance(type_spec, FederatedType): - member_lines = _lines_for_type(type_spec.member, formatted) - placement_line = '@{}'.format(type_spec.placement) - if type_spec.all_equal: - return _combine([member_lines, [placement_line]]) - else: - return _combine([['{'], member_lines, ['}'], [placement_line]]) - elif isinstance(type_spec, FunctionType): - if type_spec.parameter is not None: - parameter_lines = _lines_for_type(type_spec.parameter, formatted) - else: - parameter_lines = [''] - result_lines = _lines_for_type(type_spec.result, formatted) - return _combine([['('], parameter_lines, [' -> '], result_lines, [')']]) - elif isinstance(type_spec, StructType): - if not type_spec: - return ['<>'] - elements = list(type_spec.items()) - elements_lines = _lines_for_named_types(elements, formatted) - if formatted: - elements_lines = _indent(elements_lines) - lines = [['<', ''], elements_lines, ['', '>']] - else: - lines = [['<'], elements_lines, ['>']] - return _combine(lines) - elif isinstance(type_spec, PlacementType): - return ['placement'] - elif isinstance(type_spec, SequenceType): - element_lines = _lines_for_type(type_spec.element, formatted) - return _combine([element_lines, ['*']]) - elif isinstance(type_spec, TensorType): - if type_spec.shape is None: - return ['{!r}(shape=None)'.format(type_spec.dtype.name)] - elif type_spec.shape: - - def _value_string(value): - return str(value) if value is not None else '?' - - value_strings = [_value_string(e) for e in type_spec.shape] - values_strings = ','.join(value_strings) - return ['{}[{}]'.format(type_spec.dtype.name, values_strings)] - else: - return [type_spec.dtype.name] - else: - raise NotImplementedError( - 'Unexpected type found: {}.'.format(type(type_spec)) - ) - - lines = _lines_for_type(type_spec, formatted) - lines = [line.rstrip() for line in lines] - if formatted: - return '\n'.join(lines) - else: - return ''.join(lines) diff --git a/tensorflow_federated/python/core/impl/types/computation_types_test.py b/tensorflow_federated/python/core/impl/types/computation_types_test.py deleted file mode 100644 index c944e452e3..0000000000 --- a/tensorflow_federated/python/core/impl/types/computation_types_test.py +++ /dev/null @@ -1,1508 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import inspect -from typing import NamedTuple - -from absl.testing import absltest -from absl.testing import parameterized -import attrs -import numpy as np - -from tensorflow_federated.python.common_libs import golden -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements - - -_ALL_INTERNED_TYPES = [ - computation_types.AbstractType, - computation_types.FederatedType, - computation_types.FunctionType, - computation_types.PlacementType, - computation_types.SequenceType, - computation_types.StructType, - computation_types.StructWithPythonType, - computation_types.TensorType, -] - - -@attrs.define -class TestAttrs: - a: int = 1 - a: bool = True - - -class TestNamedTuple(NamedTuple): - a: int = 1 - b: bool = True - - -class TypeMismatchErrorMessageTest(absltest.TestCase): - - def test_short_compact_repr(self): - first = computation_types.TensorType(np.int32) - second = computation_types.TensorType(np.bool_) - actual = computation_types.type_mismatch_error_message( - first, second, computation_types.TypeRelation.EQUIVALENT - ) - golden.check_string('short_compact_repr.expected', actual) - - def test_long_formatted_with_diff(self): - int32 = computation_types.TensorType(np.int32) - first = computation_types.StructType([(None, int32)] * 20) - second = computation_types.StructType([(None, int32)] * 21) - actual = computation_types.type_mismatch_error_message( - first, second, computation_types.TypeRelation.EQUIVALENT - ) - golden.check_string('long_formatted_with_diff.expected', actual) - - def test_container_types_full_repr(self): - first = computation_types.StructWithPythonType([], list) - second = computation_types.StructWithPythonType([], tuple) - actual = computation_types.type_mismatch_error_message( - first, second, computation_types.TypeRelation.EQUIVALENT - ) - golden.check_string('container_types_full_repr.expected', actual) - - -class InternTest(parameterized.TestCase): - - @parameterized.named_parameters( - [(cls.__name__, cls) for cls in _ALL_INTERNED_TYPES] - ) - def test_hashable_from_init_args_has_correct_parameters(self, cls): - hashable_from_init_args_signature = inspect.signature( - cls._hashable_from_init_args - ) - actual_parameters = hashable_from_init_args_signature.parameters - init_signature = inspect.signature(cls.__init__) - # A copy of the parameters is created because `mappingproxy` object does not - # support item deletion. - expected_parameters = init_signature.parameters.copy() - del expected_parameters['self'] - self.assertEqual(actual_parameters, expected_parameters) - - def test_call_raises_type_error_with_unhashable_key(self): - - class Foo(metaclass=computation_types._Intern): # pylint: disable=undefined-variable - - @classmethod - def _hashable_from_init_args(cls, *args, **kwargs): - del args, kwargs # Unused. - return [] - - with self.assertRaises(TypeError): - _ = Foo() - - -class TypeTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'tensor_type_same_dtype_and_shape', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ) - def test_check_equivalent_to_does_not_raise_types_not_equivalent_error( - self, type_spec, other - ): - try: - type_spec.check_equivalent_to(other) - except computation_types.TypesNotEquivalentError: - self.fail('Raised `TypesNotEquivalentError` unexpectedly.') - - @parameterized.named_parameters( - ( - 'tensor_type_different_dtype', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.bool_), - ), - ( - 'tensor_type_different_shape', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32, (10,)), - ), - ) - def test_check_equivalent_to_returns_false(self, type_spec, other): - with self.assertRaises(computation_types.TypesNotEquivalentError): - type_spec.check_equivalent_to(other) - - -class TensorTypeTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'tensor_type', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ( - 'tensor_type_ndims_unknown', - computation_types.TensorType(np.int32, (None,)), - computation_types.TensorType(np.int32, (None,)), - ), - ) - def test_interned(self, type_spec_1, type_spec_2): - self.assertIs(type_spec_1, type_spec_2) - - def test_init_infers_shape(self): - type_spec = computation_types.TensorType(np.int32) - self.assertEqual(type_spec.shape, ()) - - @parameterized.named_parameters( - ( - 'rank_unknown', - computation_types.TensorType(np.int32), - 'int32', - ), - ( - 'ndims_unknown', - computation_types.TensorType(np.int32, (None,)), - 'int32[?]', - ), - ( - 'ndims_10', - computation_types.TensorType(np.int32, (10,)), - 'int32[10]', - ), - ) - def test_str(self, type_spec, expected_str): - actual_str = str(type_spec) - self.assertEqual(actual_str, expected_str) - - @parameterized.named_parameters( - ( - 'rank_unknown', - computation_types.TensorType(np.int32), - 'TensorType(np.int32)', - ), - ( - 'ndims_unknown', - computation_types.TensorType(np.int32, (None,)), - 'TensorType(np.int32, (None,))', - ), - ( - 'ndims_ten', - computation_types.TensorType(np.int32, (10,)), - 'TensorType(np.int32, (10,))', - ), - ) - def test_repr(self, type_spec, expected_repr): - actual_repr = repr(type_spec) - self.assertEqual(actual_repr, expected_repr) - - @parameterized.named_parameters( - ( - 'same_dtype_and_shape', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - True, - ), - ( - 'different_dtype', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.bool_), - False, - ), - ( - 'different_shape', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32, (10,)), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'same_dtype_and_shape', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - True, - ), - ( - 'different_dtype', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.bool_), - False, - ), - ( - 'different_shape', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32, (10,)), - False, - ), - ( - 'ndims_unknown_from_known', - computation_types.TensorType(np.int32, (None,)), - computation_types.TensorType(np.int32, (10,)), - True, - ), - ( - 'ndims_known_from_unknown', - computation_types.TensorType(np.int32, (10,)), - computation_types.TensorType(np.int32, (None,)), - False, - ), - ) - def test_is_assignable_from(self, type_spec, other, expected_result): - actual_result = type_spec.is_assignable_from(other) - self.assertEqual(actual_result, expected_result) - - -class StructTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.StructType([np.int32, np.bool_]) - type_spec_2 = computation_types.StructType([np.int32, np.bool_]) - self.assertIs(type_spec_1, type_spec_2) - - @parameterized.named_parameters( - ('__init__', '__init__'), - ('python_container', 'python_container'), - ('items', 'items'), - ) - def test_init_raises_value_error_with_reserved_name(self, name): - with self.assertRaises(ValueError): - computation_types.StructType([(name, np.int32)]) - - @parameterized.named_parameters( - ( - 'unnamed', - computation_types.StructType([np.int32, np.bool_]), - '', - ), - ( - 'named', - computation_types.StructType([('a', np.int32), ('b', np.bool_)]), - '', - ), - ) - def test_str(self, type_spec, expected_str): - actual_str = str(type_spec) - self.assertEqual(actual_str, expected_str) - - @parameterized.named_parameters( - ( - 'unnamed', - computation_types.StructType([np.int32, np.bool_]), - 'StructType([TensorType(np.int32), TensorType(np.bool_)])', - ), - ( - 'named', - computation_types.StructType([('a', np.int32), ('b', np.bool_)]), - ( - 'StructType([' - "('a', TensorType(np.int32)), " - "('b', TensorType(np.bool_))" - '])' - ) - , - ), - ) - def test_repr(self, type_spec, expected_repr): - actual_repr = repr(type_spec) - self.assertEqual(actual_repr, expected_repr) - - @parameterized.named_parameters( - ( - 'same_elements_unnamed', - computation_types.StructType([np.int32, np.bool_]), - computation_types.StructType([np.int32, np.bool_]), - True, - ), - ( - 'same_elements_named', - computation_types.StructType([('a', np.int32), ('b', np.bool_)]), - computation_types.StructType([('a', np.int32), ('b', np.bool_)]), - True, - ), - ( - 'different_elements_unnamed', - computation_types.StructType([np.int32, np.float64]), - computation_types.StructType([np.int32, np.int32]), - False, - ), - ( - 'different_elements_named', - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - computation_types.StructType([('a', np.int32), ('b', np.int32)]), - False, - ), - ( - 'same_elements_different_names', - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - computation_types.StructType([('a', np.int32), ('c', np.float64)]), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'same_elements_unnamed', - computation_types.StructType([np.int32, np.float64]), - computation_types.StructType([np.int32, np.float64]), - True, - ), - ( - 'same_elements_named', - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - True, - ), - ( - 'different_elements_unnamed', - computation_types.StructType([np.int32, np.float64]), - computation_types.StructType([np.int32, np.int32]), - False, - ), - ( - 'different_elements_named', - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - computation_types.StructType([('a', np.int32), ('b', np.int32)]), - False, - ), - ( - 'same_elements_different_names', - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - computation_types.StructType([('a', np.int32), ('c', np.float64)]), - False, - ), - ( - 'same_elements_unnamed_from_named', - computation_types.StructType([np.int32, np.float64]), - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - False, - ), - ( - 'same_elements_named_from_unnamed', - computation_types.StructType([('a', np.int32), ('b', np.float64)]), - computation_types.StructType([np.int32, np.float64]), - True, - ), - ) - def test_is_assignable_from(self, type_spec, other, expected_result): - actual_result = type_spec.is_assignable_from(other) - self.assertEqual(actual_result, expected_result) - - -class StructWithPythonTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.StructWithPythonType( - [np.int32, np.float64], list - ) - type_spec_2 = computation_types.StructWithPythonType( - [np.int32, np.float64], list - ) - self.assertIs(type_spec_1, type_spec_2) - - @parameterized.named_parameters( - ('__init__', '__init__'), - ('python_container', 'python_container'), - ('items', 'items'), - ) - def test_init_raises_value_error_with_reserved_name(self, name): - with self.assertRaises(ValueError): - computation_types.StructWithPythonType([(name, np.int32)], list) - - @parameterized.named_parameters( - ( - 'list_unnamed', - computation_types.StructWithPythonType([np.int32, np.float64], list), - '', - ), - ( - 'list_named', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], list - ), - '', - ), - ( - 'tuple', - computation_types.StructWithPythonType([np.int32, np.float64], tuple), - '', - ), - ( - 'dict', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], dict - ), - '', - ), - ( - 'ordered_dict', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], collections.OrderedDict - ), - '', - ), - ( - 'attrs', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], TestAttrs - ), - '', - ), - ( - 'named_tuple', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], TestNamedTuple - ), - '', - ), - ) - def test_str(self, type_spec, expected_str): - actual_str = str(type_spec) - self.assertEqual(actual_str, expected_str) - - @parameterized.named_parameters( - ( - 'list_unnamed', - computation_types.StructWithPythonType([np.int32, np.float64], list), - 'StructType([TensorType(np.int32), TensorType(np.float64)]) as list', - ), - ( - 'list_named', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], list - ), - ( - 'StructType([' - "('a', TensorType(np.int32)), " - "('b', TensorType(np.float64))" - ']) as list' - ), - ), - ( - 'tuple', - computation_types.StructWithPythonType([np.int32, np.float64], tuple), - 'StructType([TensorType(np.int32), TensorType(np.float64)]) as tuple', - ), - ( - 'dict', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], dict - ), - ( - 'StructType([' - "('a', TensorType(np.int32)), " - "('b', TensorType(np.float64))" - ']) as dict' - ), - ), - ( - 'ordered_dict', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], collections.OrderedDict - ), - ( - 'StructType([' - "('a', TensorType(np.int32)), " - "('b', TensorType(np.float64))" - ']) as OrderedDict' - ), - ), - ( - 'attrs', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], TestAttrs - ), - ( - 'StructType([' - "('a', TensorType(np.int32)), " - "('b', TensorType(np.float64))" - ']) as TestAttrs' - ), - ), - ( - 'named_tuple', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], TestNamedTuple - ), - ( - 'StructType([' - "('a', TensorType(np.int32)), " - "('b', TensorType(np.float64))" - ']) as TestNamedTuple' - ), - ), - ) - def test_repr(self, type_spec, expected_repr): - actual_repr = repr(type_spec) - self.assertEqual(actual_repr, expected_repr) - - @parameterized.named_parameters( - ( - 'same_elements_and_container_type_unnamed', - computation_types.StructWithPythonType([np.int32, np.float64], list), - computation_types.StructWithPythonType([np.int32, np.float64], list), - True, - ), - ( - 'same_elements_and_container_type_named', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], list - ), - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], list - ), - True, - ), - ( - 'different_elements', - computation_types.StructWithPythonType([np.int32, np.float64], list), - computation_types.StructWithPythonType([np.int32, np.int32], list), - False, - ), - ( - 'different_container_type', - computation_types.StructWithPythonType([np.int32, np.float64], list), - computation_types.StructWithPythonType([np.int32, np.float64], tuple), - False, - ), - ( - 'same_elements_and_container_type_different_names', - computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], list - ), - computation_types.StructWithPythonType( - [('a', np.int32), ('c', np.float64)], list - ), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - -class SequenceTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.SequenceType(np.int32) - type_spec_2 = computation_types.SequenceType(np.int32) - self.assertIs(type_spec_1, type_spec_2) - - def test_init_converts_struct_with_list_to_struct_with_tuple_with_list(self): - type_spec = computation_types.SequenceType( - computation_types.StructWithPythonType([np.int32, np.float64], list) - ) - self.assertIs(type_spec.element.python_container, tuple) - - def test_init_converts_struct_with_list_to_struct_with_tuple_with_list_nested( - self, - ): - type_spec = computation_types.SequenceType( - computation_types.StructWithPythonType( - [ - computation_types.StructWithPythonType( - [np.int32, np.float64], list - ), - computation_types.StructWithPythonType( - [np.int32, np.float64], list - ), - ], - list, - ) - ) - self.assertIs(type_spec.element.python_container, tuple) - first_element, second_element = type_spec.element - self.assertIs(first_element.python_container, tuple) - self.assertIs(second_element.python_container, tuple) - - @parameterized.named_parameters([ - ('abstract_type', computation_types.AbstractType('T')), - ('struct_type', computation_types.StructType([np.int32] * 3)), - ( - 'struct_with_python_type', - computation_types.StructWithPythonType([np.int32] * 3, list), - ), - ('placement_type', computation_types.PlacementType()), - ('tensor_type', computation_types.TensorType(np.int32)), - ]) - def test_init_does_not_raise_value_error(self, element): - try: - computation_types.SequenceType(element) - except ValueError: - self.fail('Raised `ValueError` unexpectedly.') - - @parameterized.named_parameters([ - ( - 'federated_type', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ( - 'function_type', - computation_types.FunctionType(np.int32, np.int32), - ), - ( - 'function_type_nested', - computation_types.StructType([ - computation_types.FunctionType(np.int32, np.int32), - ]), - ), - ( - 'sequence_type', - computation_types.SequenceType([np.int32]), - ), - ]) - def test_init_raises_value_error(self, element): - with self.assertRaises(ValueError): - computation_types.SequenceType(element) - - @parameterized.named_parameters( - ( - 'tensor_type', - computation_types.SequenceType(np.int32), - 'int32*', - ), - ( - 'struct_type', - computation_types.SequenceType( - computation_types.StructType([np.int32, np.float64]) - ), - '*', - ), - ) - def test_str(self, type_spec, expected_str): - actual_str = str(type_spec) - self.assertEqual(actual_str, expected_str) - - @parameterized.named_parameters( - ( - 'tensor_type', - computation_types.SequenceType(np.int32), - 'SequenceType(TensorType(np.int32))', - ), - ( - 'struct_type', - computation_types.SequenceType( - computation_types.StructType([np.int32, np.float64]) - ), - ( - 'SequenceType(StructType([TensorType(np.int32),' - ' TensorType(np.float64)]))' - ), - ), - ) - def test_repr(self, type_spec, expected_repr): - actual_repr = repr(type_spec) - self.assertEqual(actual_repr, expected_repr) - - @parameterized.named_parameters( - ( - 'same_element', - computation_types.SequenceType(np.int32), - computation_types.SequenceType(np.int32), - True, - ), - ( - 'different_element', - computation_types.SequenceType(np.int32), - computation_types.SequenceType(np.float64), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'same_element', - computation_types.SequenceType(np.int32), - computation_types.SequenceType(np.int32), - True, - ), - ( - 'different_element', - computation_types.SequenceType(np.int32), - computation_types.SequenceType(np.float64), - False, - ), - ) - def test_is_assignable_from(self, type_spec, other, expected_result): - actual_result = type_spec.is_assignable_from(other) - self.assertEqual(actual_result, expected_result) - - -class FunctionTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.FunctionType(np.int32, np.int32) - type_spec_2 = computation_types.FunctionType(np.int32, np.int32) - self.assertIs(type_spec_1, type_spec_2) - - @parameterized.named_parameters( - ( - 'with_parameter', - computation_types.FunctionType(np.int32, np.float64), - '(int32 -> float64)', - ), - ( - 'without_parameter', - computation_types.FunctionType(None, np.float64), - '( -> float64)', - ), - ) - def test_str(self, type_spec, expected_str): - actual_str = str(type_spec) - self.assertEqual(actual_str, expected_str) - - @parameterized.named_parameters( - ( - 'with_parameter', - computation_types.FunctionType(np.int32, np.float64), - 'FunctionType(TensorType(np.int32), TensorType(np.float64))', - ), - ( - 'without_parameter', - computation_types.FunctionType(None, np.float64), - 'FunctionType(None, TensorType(np.float64))', - ), - ) - def test_repr(self, type_spec, expected_repr): - actual_repr = repr(type_spec) - self.assertEqual(actual_repr, expected_repr) - - @parameterized.named_parameters( - ( - 'same_parameter_and_result', - computation_types.FunctionType(np.int32, np.float64), - computation_types.FunctionType(np.int32, np.float64), - True, - ), - ( - 'different_parameter', - computation_types.FunctionType(np.int32, np.float64), - computation_types.FunctionType(np.float64, np.float64), - False, - ), - ( - 'different_result', - computation_types.FunctionType(np.int32, np.float64), - computation_types.FunctionType(np.int32, np.int32), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'same_parameter_and_result', - computation_types.FunctionType(np.int32, np.float64), - computation_types.FunctionType(np.int32, np.float64), - True, - ), - ( - 'different_parameter', - computation_types.FunctionType(np.int32, np.float64), - computation_types.FunctionType(np.float64, np.float64), - False, - ), - ( - 'different_result', - computation_types.FunctionType(np.int32, np.float64), - computation_types.FunctionType(np.int32, np.int32), - False, - ), - ) - def test_is_assignable_from(self, type_spec, other, expected_result): - actual_result = type_spec.is_assignable_from(other) - self.assertEqual(actual_result, expected_result) - - -class AbstractTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.AbstractType('T') - type_spec_2 = computation_types.AbstractType('T') - self.assertIs(type_spec_1, type_spec_2) - - def test_str(self): - type_spec = computation_types.AbstractType('T') - actual_str = str(type_spec) - self.assertEqual(actual_str, 'T') - - def test_repr(self): - type_spec = computation_types.AbstractType('T') - actual_str = repr(type_spec) - self.assertEqual(actual_str, "AbstractType('T')") - - @parameterized.named_parameters( - ( - 'same_label', - computation_types.AbstractType('T'), - computation_types.AbstractType('T'), - True, - ), - ( - 'different_label', - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'same_label', - computation_types.AbstractType('T'), - computation_types.AbstractType('T'), - ), - ( - 'different_label', - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ), - ) - def test_is_assignable_from(self, type_spec, other): - with self.assertRaises(TypeError): - type_spec.is_assignable_from(other) - - -class PlacementTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.PlacementType() - type_spec_2 = computation_types.PlacementType() - self.assertIs(type_spec_1, type_spec_2) - - def test_str(self): - type_spec = computation_types.PlacementType() - actual_str = str(type_spec) - self.assertEqual(actual_str, 'placement') - - def test_repr(self): - type_spec = computation_types.PlacementType() - actual_str = repr(type_spec) - self.assertEqual(actual_str, 'PlacementType()') - - @parameterized.named_parameters( - ( - 'placement_type', - computation_types.PlacementType(), - computation_types.PlacementType(), - True, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'placement_type', - computation_types.PlacementType(), - computation_types.PlacementType(), - True, - ), - ) - def test_is_assignable_from(self, type_spec, other, expected_result): - actual_result = type_spec.is_assignable_from(other) - self.assertEqual(actual_result, expected_result) - - -class FederatedTypeTest(parameterized.TestCase): - - def test_interned(self): - type_spec_1 = computation_types.FederatedType(np.int32, placements.CLIENTS) - type_spec_2 = computation_types.FederatedType(np.int32, placements.CLIENTS) - self.assertIs(type_spec_1, type_spec_2) - - @parameterized.named_parameters( - ('clients', placements.CLIENTS, False), - ('server', placements.SERVER, True), - ) - def test_init_infers_all_equal(self, placement, expected_all_equal): - type_spec = computation_types.FederatedType(np.int32, placement) - self.assertEqual(type_spec.all_equal, expected_all_equal) - - @parameterized.named_parameters([ - ('abstract_type', computation_types.AbstractType('T')), - ('placement_type', computation_types.PlacementType()), - ('sequence_type', computation_types.SequenceType([np.int32])), - ('struct_type', computation_types.StructType([np.int32] * 3)), - ( - 'struct_with_python_type', - computation_types.StructWithPythonType([np.int32] * 3, list), - ), - ('tensor_type', computation_types.TensorType(np.int32)), - ]) - def test_init_does_not_raise_value_error(self, member): - try: - computation_types.FederatedType(member, placements.CLIENTS) - except ValueError: - self.fail('Raised `ValueError` unexpectedly.') - - @parameterized.named_parameters([ - ( - 'federated_type', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ( - 'function_type', - computation_types.FunctionType(np.int32, np.int32), - ), - ( - 'function_type_nested', - computation_types.StructType([ - computation_types.FunctionType(np.int32, np.int32), - ]), - ), - ]) - def test_init_raises_value_error(self, member): - with self.assertRaises(ValueError): - computation_types.FederatedType(member, placements.CLIENTS) - - @parameterized.named_parameters( - ( - 'clients_and_all_equal_true', - computation_types.FederatedType(np.int32, placements.CLIENTS, True), - 'int32@CLIENTS', - ), - ( - 'clients_and_all_equal_false', - computation_types.FederatedType(np.int32, placements.CLIENTS, False), - '{int32}@CLIENTS', - ), - ( - 'server_and_all_equal_true', - computation_types.FederatedType(np.int32, placements.SERVER, True), - 'int32@SERVER', - ), - ( - 'server_and_all_equal_false', - computation_types.FederatedType(np.int32, placements.SERVER, False), - '{int32}@SERVER', - ), - ) - def test_str(self, type_spec, expected_str): - actual_str = str(type_spec) - self.assertEqual(actual_str, expected_str) - - @parameterized.named_parameters( - ( - 'clients_and_all_equal_true', - computation_types.FederatedType(np.int32, placements.CLIENTS, True), - ( - "FederatedType(TensorType(np.int32), PlacementLiteral('clients')," - ' True)' - ), - ), - ( - 'clients_and_all_equal_false', - computation_types.FederatedType(np.int32, placements.CLIENTS, False), - ( - "FederatedType(TensorType(np.int32), PlacementLiteral('clients')," - ' False)' - ), - ), - ( - 'server_and_all_equal_true', - computation_types.FederatedType(np.int32, placements.SERVER, True), - ( - "FederatedType(TensorType(np.int32), PlacementLiteral('server')," - ' True)' - ), - ), - ( - 'server_and_all_equal_false', - computation_types.FederatedType(np.int32, placements.SERVER, False), - ( - "FederatedType(TensorType(np.int32), PlacementLiteral('server')," - ' False)' - ), - ), - ) - def test_repr(self, type_spec, expected_repr): - actual_repr = repr(type_spec) - self.assertEqual(actual_repr, expected_repr) - - @parameterized.named_parameters( - ( - 'same_member_and_placement_and_all_equal', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.CLIENTS), - True, - ), - ( - 'different_member', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.float64, placements.CLIENTS), - False, - ), - ( - 'different_placement', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.SERVER), - False, - ), - ( - 'different_all_equals', - computation_types.FederatedType(np.int32, placements.CLIENTS, True), - computation_types.FederatedType(np.int32, placements.CLIENTS, False), - False, - ), - ) - def test_eq(self, type_spec, other, expected_result): - actual_result = type_spec == other - self.assertEqual(actual_result, expected_result) - - @parameterized.named_parameters( - ( - 'same_member_and_placement_and_all_equal', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.CLIENTS), - True, - ), - ( - 'different_member', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.float64, placements.CLIENTS), - False, - ), - ( - 'different_placement', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.SERVER), - False, - ), - ( - 'different_all_equals', - computation_types.FederatedType(np.int32, placements.CLIENTS, True), - computation_types.FederatedType(np.int32, placements.CLIENTS, False), - False, - ), - ) - def test_is_assignable_from(self, type_spec, other, expected_result): - actual_result = type_spec.is_assignable_from(other) - self.assertEqual(actual_result, expected_result) - - -class ToTypeTest(parameterized.TestCase): - - def test_tensor_type(self): - s = computation_types.TensorType(np.int32) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.TensorType) - self.assertEqual(str(t), 'int32') - - def test_tf_type(self): - s = np.int32 - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.TensorType) - self.assertEqual(str(t), 'int32') - - def test_tf_type_and_shape(self): - s = (np.int32, [10]) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.TensorType) - self.assertEqual(str(t), 'int32[10]') - - def test_tf_type_and_shape_with_unknown_dimension(self): - s = (np.int32, [None]) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.TensorType) - self.assertEqual(str(t), 'int32[?]') - - def test_list_of_tf_types(self): - s = [np.int32, np.float64] - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertEqual(str(t), '') - - def test_tuple_of_tf_types(self): - s = (np.int32, np.float64) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, tuple) - self.assertEqual(str(t), '') - - def test_singleton_named_tf_type(self): - s = ('a', np.int32) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, tuple) - self.assertEqual(str(t), '') - - def test_list_of_named_tf_types(self): - s = [('a', np.int32), ('b', np.float64)] - t = computation_types.to_type(s) - # Note: list of pairs should be interpreted as a plain StructType, and - # not try to convert into a python list afterwards. - self.assertNotIsInstance(t, computation_types.StructWithPythonType) - self.assertEqual(str(t), '') - - def test_list_of_partially_named_tf_types(self): - s = [np.float64, ('a', np.int32)] - t = computation_types.to_type(s) - # Note: list of pairs should be interpreted as a plain StructType, and - # not try to convert into a python list afterwards. - self.assertNotIsInstance(t, computation_types.StructWithPythonType) - self.assertEqual(str(t), '') - - def test_ordered_dict_of_tf_types(self): - s = collections.OrderedDict([('a', np.int32), ('b', np.float64)]) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, collections.OrderedDict) - self.assertEqual(str(t), '') - - def test_nested_tuple_of_tf_types(self): - s = (np.int32, (np.float32, np.float64)) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, tuple) - self.assertEqual(str(t), '>') - - def test_nested_tuple_of_named_tf_types(self): - s = (np.int32, (('x', np.float32), np.float64)) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, tuple) - self.assertNotIsInstance(t[1], computation_types.StructWithPythonType) - self.assertEqual(str(t), '>') - - def test_nested_tuple_of_named_nonscalar_tf_types(self): - s = ((np.int32, [1]), (('x', (np.float32, [2])), (np.float64, [3]))) - t = computation_types.to_type(s) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, tuple) - self.assertNotIsInstance(t[1], computation_types.StructWithPythonType) - self.assertEqual(str(t), '>') - - def test_namedtuple_elements_two_tuples(self): - elems = [np.int32 for _ in range(10)] - t = computation_types.to_type(elems) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, list) - for k in structure.iter_elements(t): - self.assertLen(k, 2) - - def test_namedtuples_addressable_by_name(self): - elems = [('item' + str(k), np.int32) for k in range(5)] - t = computation_types.to_type(elems) - # Note: list of pairs should be interpreted as a plain StructType, and - # not try to convert into a python list afterwards. - self.assertNotIsInstance(t, computation_types.StructWithPythonType) - self.assertIsInstance(t.item0, computation_types.TensorType) - self.assertEqual(t.item0, t[0]) - - def test_namedtuple_unpackable(self): - elems = [('item' + str(k), np.int32) for k in range(2)] - t = computation_types.to_type(elems) - a, b = t - self.assertIsInstance(a, computation_types.TensorType) - self.assertIsInstance(b, computation_types.TensorType) - - def test_attrs_instance(self): - - @attrs.define - class TestFoo: - a: object - b: object - - t = computation_types.to_type(TestFoo(a=np.int32, b=(np.float32, [2]))) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, TestFoo) - self.assertEqual(str(t), '') - - def test_nested_attrs_class(self): - - @attrs.define - class TestFoo: - a: object - b: object - - @attrs.define - class TestFoo2: - c: object - - t = computation_types.to_type( - TestFoo(a=[np.int32, np.float64], b=TestFoo2(c=(np.float32, [2]))) - ) - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, TestFoo) - self.assertIsInstance(t.a, computation_types.StructWithPythonType) - self.assertIs(t.a.python_container, list) - self.assertIsInstance(t.b, computation_types.StructWithPythonType) - self.assertIs(t.b.python_container, TestFoo2) - self.assertEqual(str(t), ',b=>') - - def test_struct(self): - t = computation_types.to_type( - structure.Struct(( - (None, np.int32), - ('b', np.int64), - )) - ) - self.assertEqual( - t, - computation_types.StructType([ - (None, computation_types.TensorType(np.int32)), - ('b', computation_types.TensorType(np.int64)), - ]), - ) - - def test_with_np_int32(self): - t = computation_types.to_type(np.int32) - self.assertIsInstance(t, computation_types.TensorType) - self.assertEqual(t.dtype, np.int32) - self.assertEqual(t.shape, ()) - - def test_with_np_int32_in_tensor_spec(self): - t = computation_types.to_type((np.int32, [5])) - self.assertIsInstance(t, computation_types.TensorType) - self.assertEqual(t.dtype, np.int32) - self.assertEqual(t.shape, (5,)) - - def test_with_np_int32_in_dict(self): - t = computation_types.to_type(collections.OrderedDict([('foo', np.int32)])) - self.assertIsInstance(t, computation_types.StructType) - self.assertIsInstance(t.foo, computation_types.TensorType) - self.assertEqual(t.foo.dtype, np.int32) - self.assertEqual(t.foo.shape, ()) - - @parameterized.named_parameters( - ('none', None), - ('object', object()), - ) - def test_raises_type_error(self, obj): - with self.assertRaises(TypeError): - _ = computation_types.to_type(obj) - - -class RepresentationTest(absltest.TestCase): - - def test_returns_string_for_abstract_type(self): - type_spec = computation_types.AbstractType('T') - - self.assertEqual(type_spec.compact_representation(), 'T') - self.assertEqual(type_spec.formatted_representation(), 'T') - - def test_returns_string_for_federated_type_clients(self): - type_spec = computation_types.FederatedType(np.int32, placements.CLIENTS) - - self.assertEqual(type_spec.compact_representation(), '{int32}@CLIENTS') - self.assertEqual(type_spec.formatted_representation(), '{int32}@CLIENTS') - - def test_returns_string_for_federated_type_server(self): - type_spec = computation_types.FederatedType(np.int32, placements.SERVER) - - self.assertEqual(type_spec.compact_representation(), 'int32@SERVER') - self.assertEqual(type_spec.formatted_representation(), 'int32@SERVER') - - def test_returns_string_for_function_type(self): - type_spec = computation_types.FunctionType(np.int32, np.float32) - - self.assertEqual(type_spec.compact_representation(), '(int32 -> float32)') - self.assertEqual(type_spec.formatted_representation(), '(int32 -> float32)') - - def test_returns_string_for_function_type_with_named_tuple_type_parameter( - self, - ): - parameter = computation_types.StructType((np.int32, np.float32)) - type_spec = computation_types.FunctionType(parameter, np.float64) - - self.assertEqual( - type_spec.compact_representation(), '( -> float64)' - ) - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '(<\n' - ' int32,\n' - ' float32\n' - '> -> float64)' - ) - # pyformat: enable - - def test_returns_string_for_function_type_with_named_tuple_type_result(self): - result = computation_types.StructType((np.int32, np.float32)) - type_spec = computation_types.FunctionType(np.float64, result) - - self.assertEqual( - type_spec.compact_representation(), '(float64 -> )' - ) - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '(float64 -> <\n' - ' int32,\n' - ' float32\n' - '>)' - ) - # pyformat: enable - - def test_returns_string_for_function_type_with_named_tuple_type_parameter_and_result( - self, - ): - parameter = computation_types.StructType((np.int32, np.float32)) - result = computation_types.StructType((np.float64, np.str_)) - type_spec = computation_types.FunctionType(parameter, result) - - self.assertEqual( - type_spec.compact_representation(), '( -> )' - ) - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '(<\n' - ' int32,\n' - ' float32\n' - '> -> <\n' - ' float64,\n' - ' str\n' - '>)' - ) - # pyformat: enable - - def test_returns_string_for_named_tuple_type_unnamed(self): - type_spec = computation_types.StructType((np.int32, np.float32)) - - self.assertEqual(type_spec.compact_representation(), '') - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '<\n' - ' int32,\n' - ' float32\n' - '>' - ) - # pyformat: enable - - def test_returns_string_for_named_tuple_type_named(self): - type_spec = computation_types.StructType( - (('a', np.int32), ('b', np.float32)) - ) - - self.assertEqual(type_spec.compact_representation(), '') - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '<\n' - ' a=int32,\n' - ' b=float32\n' - '>' - ) - # pyformat: enable - - def test_returns_string_for_named_tuple_type_nested(self): - type_spec_1 = computation_types.StructType((np.int32, np.float32)) - type_spec_2 = computation_types.StructType((type_spec_1, np.bool_)) - type_spec_3 = computation_types.StructType((type_spec_2, np.str_)) - type_spec = type_spec_3 - - self.assertEqual( - type_spec.compact_representation(), '<<,bool>,str>' - ) - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '<\n' - ' <\n' - ' <\n' - ' int32,\n' - ' float32\n' - ' >,\n' - ' bool\n' - ' >,\n' - ' str\n' - '>' - ) - # pyformat: enable - - def test_returns_string_for_named_tuple_type_with_one_element(self): - type_spec = computation_types.StructType((np.int32,)) - - self.assertEqual(type_spec.compact_representation(), '') - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '<\n' - ' int32\n' - '>' - ) - # pyformat: enable - - def test_returns_string_for_named_tuple_type_with_no_element(self): - type_spec = computation_types.StructType([]) - - self.assertEqual(type_spec.compact_representation(), '<>') - self.assertEqual(type_spec.formatted_representation(), '<>') - - def test_returns_string_for_placement_type(self): - type_spec = computation_types.PlacementType() - - self.assertEqual(type_spec.compact_representation(), 'placement') - self.assertEqual(type_spec.formatted_representation(), 'placement') - - def test_returns_string_for_sequence_type_int(self): - type_spec = computation_types.SequenceType(np.int32) - - self.assertEqual(type_spec.compact_representation(), 'int32*') - self.assertEqual(type_spec.formatted_representation(), 'int32*') - - def test_returns_string_for_sequence_type_float(self): - type_spec = computation_types.SequenceType(np.float32) - - self.assertEqual(type_spec.compact_representation(), 'float32*') - self.assertEqual(type_spec.formatted_representation(), 'float32*') - - def test_returns_string_for_sequence_type_named_tuple_type(self): - element = computation_types.StructType((np.int32, np.float32)) - type_spec = computation_types.SequenceType(element) - - self.assertEqual(type_spec.compact_representation(), '*') - # pyformat: disable - self.assertEqual( - type_spec.formatted_representation(), - '<\n' - ' int32,\n' - ' float32\n' - '>*' - ) - # pyformat: enable - - def test_returns_string_for_tensor_type_int(self): - type_spec = computation_types.TensorType(np.int32) - - self.assertEqual(type_spec.compact_representation(), 'int32') - self.assertEqual(type_spec.formatted_representation(), 'int32') - - def test_returns_string_for_tensor_type_float(self): - type_spec = computation_types.TensorType(np.float32) - - self.assertEqual(type_spec.compact_representation(), 'float32') - self.assertEqual(type_spec.formatted_representation(), 'float32') - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/container_types_full_repr.expected b/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/container_types_full_repr.expected deleted file mode 100644 index 39b61ff2b1..0000000000 --- a/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/container_types_full_repr.expected +++ /dev/null @@ -1 +0,0 @@ -Type `StructType([]) as list` is not equivalent to type `StructType([]) as tuple` diff --git a/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/long_formatted_with_diff.expected b/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/long_formatted_with_diff.expected deleted file mode 100644 index faba13ea60..0000000000 --- a/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/long_formatted_with_diff.expected +++ /dev/null @@ -1,60 +0,0 @@ -Type -`< - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32 ->` -is not equivalent to type -`< - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32, - int32 ->` - -Diff: ---- - -+++ - -@@ -1,4 +1,5 @@ - - < -+ int32, - int32, - int32, - int32, diff --git a/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/short_compact_repr.expected b/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/short_compact_repr.expected deleted file mode 100644 index 25af57431e..0000000000 --- a/tensorflow_federated/python/core/impl/types/computation_types_test_goldens/short_compact_repr.expected +++ /dev/null @@ -1 +0,0 @@ -Type `int32` is not equivalent to type `bool` diff --git a/tensorflow_federated/python/core/impl/types/dtype_utils.py b/tensorflow_federated/python/core/impl/types/dtype_utils.py deleted file mode 100644 index 3e34430aba..0000000000 --- a/tensorflow_federated/python/core/impl/types/dtype_utils.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2024, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# limitations under the License. -"""Utilities for working with dtypes.""" - -from collections.abc import Mapping -from typing import Union -import warnings - -import ml_dtypes -import numpy as np - -from tensorflow_federated.proto.v0 import data_type_pb2 - - -# Mapping from `DataType` to `type[np.generic]`. -_PROTO_TO_DTYPE: Mapping[data_type_pb2.DataType, type[np.generic]] = { - data_type_pb2.DataType.DT_BOOL: np.bool_, - data_type_pb2.DataType.DT_INT8: np.int8, - data_type_pb2.DataType.DT_INT16: np.int16, - data_type_pb2.DataType.DT_INT32: np.int32, - data_type_pb2.DataType.DT_INT64: np.int64, - data_type_pb2.DataType.DT_UINT8: np.uint8, - data_type_pb2.DataType.DT_UINT16: np.uint16, - data_type_pb2.DataType.DT_UINT32: np.uint32, - data_type_pb2.DataType.DT_UINT64: np.uint64, - data_type_pb2.DataType.DT_HALF: np.float16, - data_type_pb2.DataType.DT_FLOAT: np.float32, - data_type_pb2.DataType.DT_DOUBLE: np.float64, - data_type_pb2.DataType.DT_COMPLEX64: np.complex64, - data_type_pb2.DataType.DT_COMPLEX128: np.complex128, - data_type_pb2.DataType.DT_BFLOAT16: ml_dtypes.bfloat16, - data_type_pb2.DataType.DT_STRING: np.str_, -} - - -def from_proto( - dtype_pb: data_type_pb2.DataType, -) -> type[np.generic]: - """Returns a `type[np.generic]` for the `dtype_pb`.""" - if dtype_pb in _PROTO_TO_DTYPE: - return _PROTO_TO_DTYPE[dtype_pb] - else: - raise NotImplementedError(f'Unexpected dtype found: {dtype_pb}.') - - -# Mapping from `type[np.generic]` to `DataType`. -_DTYPE_TO_PROTO: Mapping[type[np.generic], data_type_pb2.DataType] = { - np.bool_: data_type_pb2.DataType.DT_BOOL, - np.int8: data_type_pb2.DataType.DT_INT8, - np.int16: data_type_pb2.DataType.DT_INT16, - np.int32: data_type_pb2.DataType.DT_INT32, - np.int64: data_type_pb2.DataType.DT_INT64, - np.uint8: data_type_pb2.DataType.DT_UINT8, - np.uint16: data_type_pb2.DataType.DT_UINT16, - np.uint32: data_type_pb2.DataType.DT_UINT32, - np.uint64: data_type_pb2.DataType.DT_UINT64, - np.float16: data_type_pb2.DataType.DT_HALF, - np.float32: data_type_pb2.DataType.DT_FLOAT, - np.float64: data_type_pb2.DataType.DT_DOUBLE, - np.complex64: data_type_pb2.DataType.DT_COMPLEX64, - np.complex128: data_type_pb2.DataType.DT_COMPLEX128, - ml_dtypes.bfloat16: data_type_pb2.DataType.DT_BFLOAT16, - np.str_: data_type_pb2.DataType.DT_STRING, -} - - -def to_proto(dtype: type[np.generic]) -> data_type_pb2.DataType: - """Returns a `DataType` for the `dtype`.""" - if dtype in _DTYPE_TO_PROTO: - return _DTYPE_TO_PROTO[dtype] - else: - raise NotImplementedError(f'Unexpected dtype found: {dtype}.') - - -def is_valid_dtype(dtype: type[np.generic]) -> bool: - """Returns `True` if `dtype` is valid, otherwise `False`.""" - return dtype in _DTYPE_TO_PROTO - - -def can_cast( - value: Union[bool, int, float, complex, str, bytes], - dtype: type[np.generic], -) -> bool: - """Returns `True` if `value` can be cast to `dtype`, otherwise `False`. - - This function is intended to be used to determine if the size of the `dtype` - is capable of holding the `value`. This is useful, for example, when trying to - infer the dtype of the `value`. This function is not intended to be used to - determine if you **should** cast a the `value` to `dtype`. - - Args: - value: The value to check. - dtype: The dtype to check against. - """ - - # `np.can_cast` does not support Python scalars (since version 2.0). Casting - # the value to a numpy value and testing for an overflow is equivalent to - # testing the Python value. - numpy_version = tuple(int(x) for x in np.__version__.split('.')) - if numpy_version >= (2, 0): - # When encountering an overflow, numpy issues a `RuntimeWarning` for - # floating dtypes and raises an `OverflowError` for integer dtypes. - with warnings.catch_warnings(action='error', category=RuntimeWarning): - try: - np.asarray(value, dtype=dtype) - return True - except (OverflowError, RuntimeWarning): - return False - else: - return np.can_cast(value, dtype) - - -def infer_dtype( - obj: Union[bool, int, float, complex, str, bytes], -) -> type[np.generic]: - """Returns a scalar numpy dtype for a Python scalar. - - Args: - obj: A Python scalar. - - Returns: - A scalar numpy dtype. - """ - if isinstance(obj, bool): - return np.bool_ - elif isinstance(obj, int): - if can_cast(obj, np.int32): - return np.int32 - elif can_cast(obj, np.int64): - return np.int64 - else: - raise ValueError( - 'Expected `obj` to be an `int` in the range' - f' [{np.iinfo(np.int64).min}, {np.iinfo(np.int64).max}],' - f' found: {obj}.' - ) - elif isinstance(obj, float): - return np.float32 - elif isinstance(obj, complex): - return np.complex128 - elif isinstance(obj, (str, bytes)): - return np.str_ - else: - raise NotImplementedError(f'Unexpected type found: {type(obj)}.') diff --git a/tensorflow_federated/python/core/impl/types/dtype_utils_test.py b/tensorflow_federated/python/core/impl/types/dtype_utils_test.py deleted file mode 100644 index 284f1516d5..0000000000 --- a/tensorflow_federated/python/core/impl/types/dtype_utils_test.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2024, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized -import ml_dtypes -import numpy as np - -from tensorflow_federated.python.core.impl.types import dtype_utils - - -class DtypeUtilsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('none', None), - ('object', object()), - ) - def test_from_proto_raises_not_implemented_error(self, dtype_pb): - with self.assertRaises(NotImplementedError): - dtype_utils.from_proto(dtype_pb) - - @parameterized.named_parameters( - ('none', None), - ('object', object()), - ) - def test_to_proto_raises_not_implemented_error(self, dtype): - with self.assertRaises(NotImplementedError): - dtype_utils.to_proto(dtype) - - @parameterized.named_parameters( - ('bool', np.bool_), - ('int8', np.int8), - ('int16', np.int16), - ('int32', np.int32), - ('int64', np.int64), - ('uint8', np.uint8), - ('uint16', np.uint16), - ('uint32', np.uint32), - ('uint64', np.uint64), - ('float16', np.float16), - ('float32', np.float32), - ('float64', np.float64), - ('complex64', np.complex64), - ('complex128', np.complex128), - ('bfloat16', ml_dtypes.bfloat16), - ('str', np.str_), - ) - def test_is_valid_dtype_returns_true(self, dtype): - self.assertTrue(dtype_utils.is_valid_dtype(dtype)) - - @parameterized.named_parameters( - ('bytes', np.bytes_), - ('object', np.object_), - ) - def test_is_valid_dtype_returns_false(self, dtype): - self.assertFalse(dtype_utils.is_valid_dtype(dtype)) - - @parameterized.named_parameters( - ('bool', True, np.bool_), - ('int8', 1, np.int8), - ('int16', 1, np.int16), - ('int32', 1, np.int32), - ('int64', 1, np.int64), - ('uint8', 1, np.uint8), - ('uint16', 1, np.uint16), - ('uint32', 1, np.uint32), - ('uint64', 1, np.uint64), - ('float16', 1.0, np.float16), - ('float32', 1.0, np.float32), - ('float64', 1.0, np.float64), - ('complex64', complex(1.0, 1.0), np.complex64), - ('complex128', complex(1.0, 1.0), np.complex128), - ('str', 'a', np.str_), - ('bytes', b'a', np.bytes_), - ) - def test_can_cast_returns_true(self, value, dtype): - result = dtype_utils.can_cast(value, dtype) - self.assertTrue(result) - - @parameterized.named_parameters( - ('int64', np.iinfo(np.int64).max, np.int32), - ('float64', float(np.finfo(np.float64).max), np.float32), - ('complex64', complex(np.finfo(np.float64).max, 1), np.complex64), - ) - def test_can_cast_returns_false(self, value, dtype): - result = dtype_utils.can_cast(value, dtype) - self.assertFalse(result) - - @parameterized.named_parameters( - ('bool', True, np.bool_), - ('int32', 1, np.int32), - ('int32_min', int(np.iinfo(np.int32).min), np.int32), - ('int32_max', int(np.iinfo(np.int32).max), np.int32), - ('int64_min', int(np.iinfo(np.int64).min), np.int64), - ('int64_max', int(np.iinfo(np.int64).max), np.int64), - ('float', 1.0, np.float32), - ('complex', complex(1.0, 1.0), np.complex128), - ('str', 'a', np.str_), - ('bytes', b'a', np.str_), - ) - def test_infer_dtype(self, obj, expected_value): - actual_value = dtype_utils.infer_dtype(obj) - self.assertEqual(actual_value, expected_value) - - @parameterized.named_parameters( - ('int_min', int(np.iinfo(np.int64).min) - 1), - ('int_max', int(np.iinfo(np.int64).max) + 1), - ) - def test_infer_dtype_raises_value_error(self, obj): - with self.assertRaises(ValueError): - dtype_utils.infer_dtype(obj) - - @parameterized.named_parameters( - ('none', None), - ('object', object()), - ) - def test_infer_dtype_raises_not_implemented_error(self, obj): - with self.assertRaises(NotImplementedError): - dtype_utils.infer_dtype(obj) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/placements.py b/tensorflow_federated/python/core/impl/types/placements.py deleted file mode 100644 index fe8b4ee6bc..0000000000 --- a/tensorflow_federated/python/core/impl/types/placements.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Placement literals for use in computation and type definitions.""" - - -class PlacementLiteral: - """A representation of one of the globally recognized placement literals.""" - - def __init__( - self, name: str, uri: str, default_all_equal: bool, description: str - ): - self._name = name - self._uri = uri - self._default_all_equal = default_all_equal - self._description = description - - @property - def name(self) -> str: - return self._name - - @property - def uri(self) -> str: - return self._uri - - @property - def default_all_equal(self) -> bool: - return self._default_all_equal - - def is_server(self) -> bool: - return self is SERVER - - def is_clients(self) -> bool: - return self is CLIENTS - - def __doc__(self) -> str: - return self._description - - def __str__(self) -> str: - return self._name - - def __repr__(self) -> str: - return "PlacementLiteral('{}')".format(self._uri) - - def __eq__(self, other: object) -> bool: - if self is other: - return True - elif not isinstance(other, PlacementLiteral): - return NotImplemented - return self._uri == other.uri - - def __hash__(self) -> int: - return hash(self._uri) - - -# TODO: b/113112108 - Define the remaining placement literals (e.g., -# intermediate coordinators). Possibly rename SERVER to COORDINATOR or some such -# if desired. - -CLIENTS = PlacementLiteral( - 'CLIENTS', - 'clients', - default_all_equal=False, - description='The collective of all client devices.', -) - -SERVER = PlacementLiteral( - 'SERVER', - 'server', - default_all_equal=True, - description='The single top-level central coordinator.', -) - - -def uri_to_placement_literal(uri: str) -> PlacementLiteral: - """Returns the placement literal that corresponds to the given URI. - - Args: - uri: The URI of the placement. - - Returns: - The placement literal. - - Raises: - ValueError: if there is no known placement literal with such URI. - """ - for literal in [CLIENTS, SERVER]: - if uri == literal.uri: - return literal - raise ValueError(f'There is no known literal with uri `{uri}`.') diff --git a/tensorflow_federated/python/core/impl/types/placements_test.py b/tensorflow_federated/python/core/impl/types/placements_test.py deleted file mode 100644 index 18cb1626b2..0000000000 --- a/tensorflow_federated/python/core/impl/types/placements_test.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.types import placements - - -class PlacementLiteralsTest(absltest.TestCase): - - def test_something(self): - self.assertNotEqual(str(placements.CLIENTS), str(placements.SERVER)) - for literal in [placements.CLIENTS, placements.SERVER]: - self.assertIs(placements.uri_to_placement_literal(literal.uri), literal) - - def test_comparators_and_hashing(self): - self.assertEqual(placements.CLIENTS, placements.CLIENTS) - self.assertNotEqual(placements.CLIENTS, placements.SERVER) - self.assertEqual(hash(placements.CLIENTS), hash(placements.CLIENTS)) - self.assertNotEqual(hash(placements.CLIENTS), hash(placements.SERVER)) - foo = {placements.CLIENTS: 10, placements.SERVER: 20} - self.assertEqual(foo[placements.CLIENTS], 10) - self.assertEqual(foo[placements.SERVER], 20) - - def test_comparison_to_none(self): - self.assertNotEqual(placements.CLIENTS, None) - self.assertNotEqual(placements.SERVER, None) - self.assertNotEqual(None, placements.CLIENTS) - self.assertNotEqual(None, placements.SERVER) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/type_analysis.py b/tensorflow_federated/python/core/impl/types/type_analysis.py deleted file mode 100644 index 5470adadaa..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_analysis.py +++ /dev/null @@ -1,882 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of static analysis functions for computation types.""" - -import collections -from collections.abc import Callable -from typing import Optional - -import ml_dtypes -import numpy as np - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import type_transformations - -_TypePredicate = Callable[[computation_types.Type], bool] - - -def preorder_types(type_signature: computation_types.Type): - """Yields each type in `type_signature` in a preorder fashion.""" - yield type_signature - for child in type_signature.children(): - yield from preorder_types(child) - - -def count( - type_signature: computation_types.Type, predicate: _TypePredicate -) -> int: - """Returns the number of types in `type_signature` matching `predicate`. - - Args: - type_signature: A tree of `computation_type.Type`s to count. - predicate: A Python function that takes a type as a parameter and returns a - boolean value. - """ - one_or_zero = lambda t: 1 if predicate(t) else 0 - return sum(map(one_or_zero, preorder_types(type_signature))) - - -def contains( - type_signature: computation_types.Type, predicate: _TypePredicate -) -> bool: - """Checks if `type_signature` contains any types that pass `predicate`.""" - for t in preorder_types(type_signature): - if predicate(t): - return True - return False - - -def contains_federated_types(type_signature): - """Returns whether or not `type_signature` contains a federated type.""" - return contains( - type_signature, lambda t: isinstance(t, computation_types.FederatedType) - ) - - -def contains_tensor_types(type_signature): - """Returns whether or not `type_signature` contains a tensor type.""" - return contains( - type_signature, lambda t: isinstance(t, computation_types.TensorType) - ) - - -def contains_only( - type_signature: computation_types.Type, - predicate: _TypePredicate, -) -> bool: - """Checks if `type_signature` contains only types that pass `predicate`.""" - return not contains(type_signature, lambda t: not predicate(t)) - - -def check_type(value: object, type_spec: computation_types.Type): - """Checks whether `val` is of TFF type `type_spec`. - - Args: - value: The object to check. - type_spec: A `computation_types.Type`, the type that `value` is checked - against. - - Raises: - TypeError: If the inferred type of `value` is not assignable to `type_spec`. - """ - py_typecheck.check_type(type_spec, computation_types.Type) - value_type = type_conversions.infer_type(value) - if not type_spec.is_assignable_from(value_type): - raise TypeError( - computation_types.type_mismatch_error_message( - value_type, - type_spec, - computation_types.TypeRelation.ASSIGNABLE, - second_is_expected=True, - ) - ) - - -def is_tensorflow_compatible_type(type_spec): - """Checks `type_spec` against an explicit list of `tf_computation`.""" - if type_spec is None: - return True - - def _predicate(type_spec: computation_types.Type) -> bool: - return isinstance( - type_spec, - ( - computation_types.SequenceType, - computation_types.StructType, - computation_types.TensorType, - ), - ) - - return contains_only(type_spec, _predicate) - - -def is_structure_of_tensors(type_spec: computation_types.Type) -> bool: - def _predicate(type_spec: computation_types.Type) -> bool: - return isinstance( - type_spec, - ( - computation_types.StructType, - computation_types.TensorType, - ), - ) - - return contains_only(type_spec, _predicate) - - -def check_tensorflow_compatible_type(type_spec): - if not is_tensorflow_compatible_type(type_spec): - raise TypeError( - 'Expected type to be compatible with TensorFlow (i.e. tensor, ' - 'sequence, or tuple types), found {}.'.format(type_spec) - ) - - -def is_generic_op_compatible_type(type_spec): - """Checks `type_spec` against an explicit list of generic operators.""" - if type_spec is None: - return False - - def _predicate(type_spec: computation_types.Type) -> bool: - return isinstance( - type_spec, - ( - computation_types.TensorType, - computation_types.StructType, - ), - ) - - return contains_only(type_spec, _predicate) - - -def is_binary_op_with_upcast_compatible_pair( - possibly_nested_type: Optional[computation_types.Type], - type_to_upcast: computation_types.Type, -) -> bool: - """Checks unambiguity in applying `type_to_upcast` to `possibly_nested_type`. - - That is, checks that either these types are equivalent and contain only - tuples and tensors, or that - `possibly_nested_type` is perhaps a nested structure containing only tensors - with `dtype` of `type_to_upcast` at the leaves, where `type_to_upcast` must - be a scalar tensor type. Notice that this relationship is not symmetric, - since binary operators need not respect this symmetry in general. - For example, it makes perfect sence to divide a nested structure of tensors - by a scalar, but not the other way around. - - Args: - possibly_nested_type: A `computation_types.Type`, or `None`. - type_to_upcast: A `computation_types.Type`, or `None`. - - Returns: - Boolean indicating whether `type_to_upcast` can be upcast to - `possibly_nested_type` in the manner described above. - """ - if possibly_nested_type is not None: - py_typecheck.check_type(possibly_nested_type, computation_types.Type) - if type_to_upcast is not None: - py_typecheck.check_type(type_to_upcast, computation_types.Type) - if not ( - is_generic_op_compatible_type(possibly_nested_type) - and is_generic_op_compatible_type(type_to_upcast) - ): - return False - if possibly_nested_type is None: - return type_to_upcast is None - if possibly_nested_type.is_equivalent_to(type_to_upcast): - return True - if not isinstance( - type_to_upcast, computation_types.TensorType - ) or not array_shape.is_shape_scalar(type_to_upcast.shape): - return False - - types_are_ok = [True] - - only_allowed_dtype = type_to_upcast.dtype # pytype: disable=attribute-error - - def _check_tensor_types(type_spec): - if ( - isinstance(type_spec, computation_types.TensorType) - and type_spec.dtype != only_allowed_dtype - ): # pytype: disable=attribute-error - types_are_ok[0] = False - return type_spec, False - - type_transformations.transform_type_postorder( - possibly_nested_type, _check_tensor_types - ) - - return types_are_ok[0] - - -def check_all_abstract_types_are_bound(type_spec): - """Checks that all abstract types labels appearing in 'type_spec' are bound. - - For abstract types to be bound, it means that type labels appearing on the - result side of functional type signatures must also appear on the parameter - side. This check is intended to verify that abstract types are only used to - model template-like type signatures, and can always be reduce to a concrete - type by specializing templates to work with specific sets of arguments. - - Examples of valid types that pass this check successfully: - - int32 - (int32 -> int32) - ( -> int32) - (T -> T) - ((T -> T) -> bool) - (( -> T) -> T) - ( T)> -> T) - (T* -> int32) - ( -> (T -> T)) - U), U> -> - - Examples of invalid types that fail this check because 'T' is unbound: - - T - (int32 -> T) - ( -> T) - (T -> U) - - Args: - type_spec: An instance of computation_types.Type, or something convertible - to it. - - Raises: - TypeError: if arguments are of the wrong types, or if unbound type labels - occur in 'type_spec'. - """ - - def _check_or_get_unbound_abstract_type_labels( - type_spec, bound_labels, check - ): - """Checks or collects abstract type labels from 'type_spec'. - - This is a helper function used by 'check_abstract_types_are_bound', not to - be exported out of this module. - - Args: - type_spec: An instance of computation_types.Type. - bound_labels: A set of string labels that refer to 'bound' abstract types, - i.e., ones that appear on the parameter side of a functional type. - check: A bool value. If True, no new unbound type labels are permitted, - and if False, any new labels encountered are returned as a set. - - Returns: - If check is False, a set of new abstract type labels introduced in - 'type_spec' that don't yet appear in the set 'bound_labels'. If check is - True, always returns an empty set. - - Raises: - TypeError: if unbound labels are found and check is True. - """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): - return set() - elif isinstance(type_spec, computation_types.SequenceType): - return _check_or_get_unbound_abstract_type_labels( - type_spec.element, bound_labels, check - ) - elif isinstance(type_spec, computation_types.FederatedType): - return _check_or_get_unbound_abstract_type_labels( - type_spec.member, bound_labels, check - ) - elif isinstance(type_spec, computation_types.StructType): - return set().union(*[ - _check_or_get_unbound_abstract_type_labels(v, bound_labels, check) - for _, v in type_spec.items() - ]) - elif isinstance(type_spec, computation_types.AbstractType): - if type_spec.label in bound_labels: - return set() - elif not check: - return set([type_spec.label]) - else: - raise TypeError("Unbound type label '{}'.".format(type_spec.label)) - elif isinstance(type_spec, computation_types.FunctionType): - if type_spec.parameter is None: - parameter_labels = set() - else: - parameter_labels = _check_or_get_unbound_abstract_type_labels( - type_spec.parameter, bound_labels, False - ) - result_labels = _check_or_get_unbound_abstract_type_labels( - type_spec.result, bound_labels.union(parameter_labels), check - ) - return parameter_labels.union(result_labels) - - _check_or_get_unbound_abstract_type_labels(type_spec, set(), True) - - -class SumIncompatibleError(TypeError): - - def __init__(self, type_spec, type_spec_context, reason): - message = ( - 'Expected a type which is compatible with the sum operator, found\n' - f'{type_spec_context}\nwhich contains\n{type_spec}\nwhich is not ' - f'sum-compatible because {reason}.' - ) - super().__init__(message) - - -def check_is_sum_compatible(type_spec, type_spec_context=None): - """Determines if `type_spec` is a type that can be added to itself. - - Types that are sum-compatible are composed of scalars of numeric types, - possibly packaged into nested named tuples, and possibly federated. Types - that are sum-incompatible include sequences, functions, abstract types, - and placements. - - Args: - type_spec: A `computation_types.Type`. - type_spec_context: An optional parent type to include in the error message. - - Raises: - SumIncompatibleError: if `type_spec` is not sum-compatible. - """ - py_typecheck.check_type(type_spec, computation_types.Type) - if type_spec_context is None: - type_spec_context = type_spec - py_typecheck.check_type(type_spec_context, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): - if not ( - np.issubdtype(type_spec.dtype, np.number) - or type_spec.dtype == ml_dtypes.bfloat16 - ): - raise SumIncompatibleError( - type_spec, type_spec_context, f'{type_spec.dtype} is not numeric' - ) - if not array_shape.is_shape_fully_defined(type_spec.shape): - raise SumIncompatibleError( - type_spec, - type_spec_context, - f'{type_spec.shape} is not fully defined', - ) - elif isinstance(type_spec, computation_types.StructType): - for _, element_type in type_spec.items(): - check_is_sum_compatible(element_type, type_spec_context) - elif isinstance(type_spec, computation_types.FederatedType): - check_is_sum_compatible(type_spec.member, type_spec_context) - else: - raise SumIncompatibleError( - type_spec, - type_spec_context, - 'only structures of tensors (possibly federated) may be summed', - ) - - -def is_structure_of_floats(type_spec: computation_types.Type) -> bool: - """Determines if `type_spec` is a structure of floats. - - Note that an empty `computation_types.StructType` will return `True`, as it - does not contain any non-floating types. - - Args: - type_spec: A `computation_types.Type`. - - Returns: - `True` iff `type_spec` is a structure of floats, otherwise `False`. - """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): - return np.issubdtype(type_spec.dtype, np.floating) - elif isinstance(type_spec, computation_types.StructType): - return all(is_structure_of_floats(v) for _, v in type_spec.items()) - elif isinstance(type_spec, computation_types.FederatedType): - return is_structure_of_floats(type_spec.member) - else: - return False - - -def check_is_structure_of_floats(type_spec): - if not is_structure_of_floats(type_spec): - raise TypeError( - 'Expected a type which is structure of floats, found {}.'.format( - type_spec - ) - ) - - -def is_structure_of_integers(type_spec: computation_types.Type) -> bool: - """Determines if `type_spec` is a structure of integers. - - Note that an empty `computation_types.StructType` will return `True`, as it - does not contain any non-integer types. - - Args: - type_spec: A `computation_types.Type`. - - Returns: - `True` iff `type_spec` is a structure of integers, otherwise `False`. - """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): - return np.issubdtype(type_spec.dtype, np.integer) - elif isinstance(type_spec, computation_types.StructType): - return all(is_structure_of_integers(v) for _, v in type_spec.items()) - elif isinstance(type_spec, computation_types.FederatedType): - return is_structure_of_integers(type_spec.member) - else: - return False - - -def check_is_structure_of_integers(type_spec): - if not is_structure_of_integers(type_spec): - raise TypeError( - 'Expected a type which is structure of integers, found {}.'.format( - type_spec - ) - ) - - -def is_single_integer_or_matches_structure( - type_sig: computation_types.Type, shape_type: computation_types.Type -) -> bool: - """If `type_sig` is an integer or integer structure matching `shape_type`.""" - - py_typecheck.check_type(type_sig, computation_types.Type) - py_typecheck.check_type(shape_type, computation_types.Type) - - if isinstance(type_sig, computation_types.TensorType): - # This condition applies to both `shape_type` being a tensor or structure, - # as the same integer bitwidth can be used for all values in the structure. - return ( - np.issubdtype(type_sig.dtype, np.integer) - and array_shape.num_elements_in_shape(type_sig.shape) == 1 - ) - elif isinstance(shape_type, computation_types.StructType) and isinstance( - type_sig, computation_types.StructType - ): - bitwidth_name_and_types = list(type_sig.items()) - shape_name_and_types = list(shape_type.items()) - if len(type_sig) != len(shape_name_and_types): - return False - for (inner_name, type_sig), (inner_shape_name, inner_shape_type) in zip( - bitwidth_name_and_types, shape_name_and_types - ): - if inner_name != inner_shape_name: - return False - if not is_single_integer_or_matches_structure(type_sig, inner_shape_type): - return False - return True - else: - return False - - -def check_federated_type( - type_spec: computation_types.FederatedType, - member: Optional[computation_types.Type] = None, - placement: Optional[placements.PlacementLiteral] = None, - all_equal: Optional[bool] = None, -): - """Checks that `type_spec` is a federated type with the given parameters. - - Args: - type_spec: The `tff.FederatedType` to check. - member: The expected member type, or `None` if unspecified. - placement: The desired placement, or `None` if unspecified. - all_equal: The desired result of accessing the property - `tff.FederatedType.all_equal` of `type_spec`, or `None` if left - unspecified. - - Raises: - TypeError: if `type_spec` is not a federated type of the given kind. - """ - py_typecheck.check_type(type_spec, computation_types.FederatedType) - if member is not None: - py_typecheck.check_type(member, computation_types.Type) - member.check_assignable_from(type_spec.member) - if placement is not None: - py_typecheck.check_type(placement, placements.PlacementLiteral) - if type_spec.placement is not placement: - raise TypeError( - 'Expected federated type placed at {}, got one placed at {}.'.format( - placement, type_spec.placement - ) - ) - if all_equal is not None: - py_typecheck.check_type(all_equal, bool) - if type_spec.all_equal != all_equal: - raise TypeError( - 'Expected federated type with all_equal {}, got one with {}.'.format( - all_equal, type_spec.all_equal - ) - ) - - -def is_average_compatible(type_spec: computation_types.Type) -> bool: - """Determines if `type_spec` can be averaged. - - Types that are average-compatible are composed of numeric tensor types, - either floating-point or complex, possibly packaged into nested named tuples, - and possibly federated. - - Args: - type_spec: a `computation_types.Type`. - - Returns: - `True` iff `type_spec` is average-compatible, `False` otherwise. - """ - py_typecheck.check_type(type_spec, computation_types.Type) - if isinstance(type_spec, computation_types.TensorType): - return np.issubdtype(type_spec, np.inexact) - elif isinstance(type_spec, computation_types.StructType): - return all(is_average_compatible(v) for _, v in type_spec.items()) - elif isinstance(type_spec, computation_types.FederatedType): - return is_average_compatible(type_spec.member) - else: - return False - - -def is_min_max_compatible(type_spec: computation_types.Type) -> bool: - """Determines if `type_spec` is min/max compatible. - - Types that are min/max-compatible are composed of integer or floating tensor - types, possibly packaged into nested tuples and possibly federated. - - Args: - type_spec: a `computation_types.Type`. - - Returns: - `True` iff `type_spec` is min/max compatible, `False` otherwise. - """ - if isinstance(type_spec, computation_types.TensorType): - return np.issubdtype(type_spec.dtype, np.integer) or np.issubdtype( - type_spec.dtype, np.floating - ) - elif isinstance(type_spec, computation_types.StructType): - return all(is_min_max_compatible(v) for _, v in type_spec.items()) - elif isinstance(type_spec, computation_types.FederatedType): - return is_min_max_compatible(type_spec.member) - else: - return False - - -def is_struct_with_py_container(value, type_spec): - return isinstance(value, structure.Struct) and isinstance( - type_spec, computation_types.StructWithPythonType - ) - - -class NotConcreteTypeError(TypeError): - - def __init__(self, full_type, found_abstract): - message = ( - 'Expected concrete type containing no abstract types, but ' - f'found abstract type {found_abstract} in {full_type}.' - ) - super().__init__(message) - - -class MismatchedConcreteTypesError(TypeError): - """Raised when there is a mismatch between two types.""" - - def __init__( - self, - full_concrete, - full_generic, - abstract_label, - first_concrete, - second_concrete, - ): - message = ( - f'Expected concrete type {full_concrete} to be a valid substitution ' - f'for generic type {full_generic}, but abstract type {abstract_label} ' - f'had substitutions {first_concrete} and {second_concrete}, which are ' - 'not equivalent.' - ) - super().__init__(message) - - -class UnassignableConcreteTypesError(TypeError): - """Raised when one type can not be assigned to another type.""" - - def __init__( - self, - full_concrete, - full_generic, - abstract_label, - definition, - not_assignable_from, - ): - message = ( - f'Expected concrete type {full_concrete} to be a valid substitution ' - f'for generic type {full_generic}, but abstract type {abstract_label} ' - f'was defined as {definition}, and later used as {not_assignable_from} ' - 'which cannot be assigned from the former.' - ) - super().__init__(message) - - -class MismatchedStructureError(TypeError): - """Raised when there is a mismatch between the structures of two types.""" - - def __init__( - self, - full_concrete, - full_generic, - concrete_member, - generic_member, - mismatch, - ): - message = ( - f'Expected concrete type {full_concrete} to be a valid substitution ' - f'for generic type {full_generic}, but their structures do not match: ' - f'{concrete_member} differs in {mismatch} from {generic_member}.' - ) - super().__init__(message) - - -class MissingDefiningUsageError(TypeError): - - def __init__(self, generic_type, label_name): - message = ( - f'Missing defining use of abstract type {label_name} in type ' - f'{generic_type}. See `check_concrete_instance_of` documentation for ' - 'details on what counts as a defining use.' - ) - super().__init__(message) - - -def check_concrete_instance_of( - concrete_type: computation_types.Type, generic_type: computation_types.Type -): - """Checks whether `concrete_type` is a valid substitution of `generic_type`. - - This function determines whether `generic_type`'s type parameters can be - substituted such that it is equivalent to `concrete type`. - - Note that passing through argument-position of function type swaps the - variance of abstract types. Argument-position types can be assigned *from* - other instances of the same type, but are not equivalent to it. - - Due to this variance issue, only abstract types must include at least one - "defining" usage. "Defining" uses are those which are encased in function - parameter position an odd number of times. These usages must all be - equivalent. Non-defining usages need not compare equal but must be assignable - *from* defining usages. - - Args: - concrete_type: A type containing no `computation_types.AbstractType`s to - check against `generic_type`'s shape. - generic_type: A type which may contain `computation_types.AbstractType`s. - - Raises: - TypeError: If `concrete_type` is not a valid substitution of `generic_type`. - """ - py_typecheck.check_type(concrete_type, computation_types.Type) - py_typecheck.check_type(generic_type, computation_types.Type) - - for t in preorder_types(concrete_type): - if isinstance(t, computation_types.AbstractType): - raise NotConcreteTypeError(concrete_type, t) - - type_bindings = {} - non_defining_usages = collections.defaultdict(list) - - def _check_helper( - generic_type_member: computation_types.Type, - concrete_type_member: computation_types.Type, - defining: bool, - ): - """Recursive helper function.""" - - def _raise_structural(mismatch): - raise MismatchedStructureError( - concrete_type, - generic_type, - concrete_type_member, - generic_type_member, - mismatch, - ) - - def _both_are(predicate): - if predicate(generic_type_member): - if predicate(concrete_type_member): - return True - else: - _raise_structural('kind') - else: - return False - - if isinstance(generic_type_member, computation_types.AbstractType): - label = str(generic_type_member.label) - if not defining: - non_defining_usages[label].append(concrete_type_member) - else: - bound_type = type_bindings.get(label) - if bound_type is not None: - if not concrete_type_member.is_equivalent_to(bound_type): - raise MismatchedConcreteTypesError( - concrete_type, - generic_type, - label, - bound_type, - concrete_type_member, - ) - else: - type_bindings[label] = concrete_type_member - elif _both_are(lambda t: isinstance(t, computation_types.TensorType)): - if generic_type_member != concrete_type_member: - _raise_structural('tensor types') - elif _both_are(lambda t: isinstance(t, computation_types.PlacementType)): - if generic_type_member != concrete_type_member: - _raise_structural('placements') - elif _both_are(lambda t: isinstance(t, computation_types.StructType)): - generic_elements = list(generic_type_member.items()) # pytype: disable=attribute-error - concrete_elements = list(concrete_type_member.items()) # pytype: disable=attribute-error - if len(generic_elements) != len(concrete_elements): - _raise_structural('length') - for generic_element, concrete_element in zip( - generic_elements, concrete_elements - ): - if generic_element[0] != concrete_element[0]: - _raise_structural('element names') - _check_helper(generic_element[1], concrete_element[1], defining) - elif _both_are(lambda t: isinstance(t, computation_types.SequenceType)): - _check_helper( - generic_type_member.element, # pytype: disable=attribute-error - concrete_type_member.element, # pytype: disable=attribute-error - defining, - ) - elif _both_are(lambda t: isinstance(t, computation_types.FunctionType)): - if generic_type_member.parameter is None: # pytype: disable=attribute-error - if concrete_type_member.parameter is not None: # pytype: disable=attribute-error - _raise_structural('parameter') - else: - _check_helper( - generic_type_member.parameter, # pytype: disable=attribute-error - concrete_type_member.parameter, # pytype: disable=attribute-error - not defining, - ) - _check_helper( - generic_type_member.result, # pytype: disable=attribute-error - concrete_type_member.result, # pytype: disable=attribute-error - defining, - ) - elif _both_are(lambda t: isinstance(t, computation_types.FederatedType)): - if generic_type_member.placement != concrete_type_member.placement: # pytype: disable=attribute-error - _raise_structural('placement') - if generic_type_member.all_equal != concrete_type_member.all_equal: # pytype: disable=attribute-error - _raise_structural('all equal') - _check_helper( - generic_type_member.member, # pytype: disable=attribute-error - concrete_type_member.member, # pytype: disable=attribute-error - defining, - ) - else: - raise TypeError(f'Unexpected type kind {generic_type}.') - - _check_helper(generic_type, concrete_type, False) - - for label, usages in non_defining_usages.items(): - bound_type = type_bindings.get(label) - if bound_type is None: - if len(usages) == 1: - # Single-use abstract types can't be wrong. - # Note: we could also add an exception here for cases where every usage - # is equivalent to the first usage. However, that's not currently - # needed since the only intrinsic that doesn't have a defining use is - # GENERIC_ZERO, which has only a single-use type parameter. - pass - else: - raise MissingDefiningUsageError(generic_type, label) - else: - for usage in usages: - if not usage.is_assignable_from(bound_type): - raise UnassignableConcreteTypesError( - concrete_type, generic_type, label, bound_type, usage - ) - - -def check_valid_federated_weighted_mean_argument_tuple_type( - type_spec: computation_types.StructType, -): - """Checks that `type_spec` is a valid type of a federated weighted mean arg. - - Args: - type_spec: A `computation_types.StructType`. - - Raises: - TypeError: If the check fails. - """ - py_typecheck.check_type(type_spec, computation_types.StructType) - if len(type_spec) != 2: - raise TypeError('Expected a 2-tuple, found {}.'.format(type_spec)) - for _, v in type_spec.items(): - check_federated_type(v, None, placements.CLIENTS, False) # pytype: disable=wrong-arg-types - if not is_average_compatible(v.member): # pytype: disable=attribute-error - raise TypeError( - 'Expected average-compatible args, got {} from argument of type {}.' - .format(v.member, type_spec) # pytype: disable=attribute-error - ) - w_type = type_spec[1].member - if ( - not isinstance(w_type, computation_types.TensorType) - or w_type.shape is None - or w_type.shape - ): - raise TypeError('Expected scalar weight, got {}.'.format(w_type)) - - -def count_tensors_in_type( - type_spec: computation_types.Type, - tensor_filter: Optional[ - Callable[[computation_types.TensorType], bool] - ] = None, -) -> collections.OrderedDict[str, int]: - """Counts tensors and fully-specified elements under `type_spec`. - - Args: - type_spec: Instance of `computation_types.Type` to count tensors under. - tensor_filter: Optional filtering function. Callable which takes an argument - of type `computation_types.TensorType` and returns a boolean. If - specified, only tensor type which pass this filter (IE, on which this - function returns `True`) will be counted. - - Returns: - A `collections.OrderedDict` with three parameters. The first, `tensors`, is - the count of all `computation_types.TensorType` (passing `tensor_filter` - if this argument is specified). The second, `parameters`, is the count - of all fully-specified parameters of these tensors. Note that this implies - any tensor with a `None` dimension (IE, of unspecified size) will not be - counted. The third counts how many tensors fall into this category (that - is, now many have unspecified size). - """ - py_typecheck.check_type(type_spec, computation_types.Type) - if tensor_filter is None: - tensor_filter = lambda _: True - - tensors_and_params = collections.OrderedDict( - num_tensors=0, parameters=0, num_unspecified_tensors=0 - ) - - def _capture_tensors(type_signature): - if isinstance( - type_signature, computation_types.TensorType - ) and tensor_filter(type_signature): - tensors_and_params['num_tensors'] += 1 - num_parameters = array_shape.num_elements_in_shape(type_signature.shape) - if num_parameters is not None: - tensors_and_params['parameters'] += num_parameters - else: - tensors_and_params['num_unspecified_tensors'] += 1 - return type_signature, False - - type_transformations.transform_type_postorder(type_spec, _capture_tensors) - return tensors_and_params diff --git a/tensorflow_federated/python/core/impl/types/type_analysis_test.py b/tensorflow_federated/python/core/impl/types/type_analysis_test.py deleted file mode 100644 index a3b71b7207..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_analysis_test.py +++ /dev/null @@ -1,893 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import collections - -from absl.testing import absltest -from absl.testing import parameterized -import ml_dtypes -import numpy as np - -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis - - -class CountTypesTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ( - 'one', - computation_types.TensorType(np.int32), - lambda t: isinstance(t, computation_types.TensorType), - 1, - ), - ( - 'three', - computation_types.StructType([np.int32] * 3), - lambda t: isinstance(t, computation_types.TensorType), - 3, - ), - ( - 'nested', - computation_types.StructType([[np.int32] * 3] * 3), - lambda t: isinstance(t, computation_types.TensorType), - 9, - ), - ]) - def test_returns_result(self, type_signature, predicate, expected_result): - result = type_analysis.count(type_signature, predicate) - self.assertEqual(result, expected_result) - - -class ContainsTypesTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ( - 'one_type', - computation_types.TensorType(np.int32), - computation_types.TensorType, - ), - ( - 'two_types', - computation_types.StructType([np.int32]), - (computation_types.StructType, computation_types.TensorType), - ), - ( - 'less_types', - computation_types.TensorType(np.int32), - (computation_types.StructType, computation_types.TensorType), - ), - ( - 'more_types', - computation_types.StructType([np.int32]), - computation_types.TensorType, - ), - ]) - def test_returns_true(self, type_signature, types): - result = type_analysis.contains( - type_signature, lambda x: isinstance(x, types) - ) - self.assertTrue(result) - - @parameterized.named_parameters([ - ( - 'one_type', - computation_types.TensorType(np.int32), - computation_types.StructType, - ), - ]) - def test_returns_false(self, type_signature, types): - result = type_analysis.contains( - type_signature, lambda x: isinstance(x, types) - ) - self.assertFalse(result) - - -class ContainsOnlyTypesTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ( - 'one_type', - computation_types.TensorType(np.int32), - computation_types.TensorType, - ), - ( - 'two_types', - computation_types.StructType([np.int32]), - (computation_types.StructType, computation_types.TensorType), - ), - ( - 'less_types', - computation_types.TensorType(np.int32), - (computation_types.StructType, computation_types.TensorType), - ), - ]) - def test_returns_true(self, type_signature, types): - result = type_analysis.contains_only( - type_signature, lambda x: isinstance(x, types) - ) - self.assertTrue(result) - - @parameterized.named_parameters([ - ( - 'one_type', - computation_types.TensorType(np.int32), - computation_types.StructType, - ), - ( - 'more_types', - computation_types.StructType([np.int32]), - computation_types.TensorType, - ), - ]) - def test_returns_false(self, type_signature, types): - result = type_analysis.contains_only( - type_signature, lambda x: isinstance(x, types) - ) - self.assertFalse(result) - - -class CheckAllAbstractTypesAreBoundTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ('tensor_type', computation_types.TensorType(np.int32)), - ( - 'function_type_with_no_arg', - computation_types.FunctionType(None, np.int32), - ), - ( - 'function_type_with_int_arg', - computation_types.FunctionType(np.int32, np.int32), - ), - ( - 'function_type_with_abstract_arg', - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('T'), - ), - ), - ( - 'tuple_tuple_function_type_with_abstract_arg', - computation_types.StructType([ - computation_types.StructType([ - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('T'), - ), - ]) - ]), - ), - ( - 'function_type_with_unbound_function_arg', - computation_types.FunctionType( - computation_types.FunctionType( - None, computation_types.AbstractType('T') - ), - computation_types.AbstractType('T'), - ), - ), - ( - 'function_type_with_sequence_arg', - computation_types.FunctionType( - computation_types.SequenceType( - computation_types.AbstractType('T') - ), - np.int32, - ), - ), - ( - 'function_type_with_two_abstract_args', - computation_types.FunctionType( - computation_types.StructType([ - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ]), - computation_types.StructType([ - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ]), - ), - ), - ]) - def test_does_not_raise_type_error(self, type_spec): - try: - type_analysis.check_all_abstract_types_are_bound(type_spec) - except TypeError: - self.fail('Raised `TypeError` unexpectedly.') - - @parameterized.named_parameters([ - ('abstract_type', computation_types.AbstractType('T')), - ( - 'function_type_with_no_arg', - computation_types.FunctionType( - None, computation_types.AbstractType('T') - ), - ), - ( - 'function_type_with_int_arg', - computation_types.FunctionType( - np.int32, computation_types.AbstractType('T') - ), - ), - ( - 'function_type_with_abstract_arg', - computation_types.FunctionType( - computation_types.AbstractType('T'), - computation_types.AbstractType('U'), - ), - ), - ]) - def test_raises_type_error(self, type_spec): - with self.assertRaises(TypeError): - type_analysis.check_all_abstract_types_are_bound(type_spec) - - -class CheckIsSumCompatibleTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ('tensor_type', computation_types.TensorType(np.int32)), - ('bfloat16_type', computation_types.TensorType(ml_dtypes.bfloat16)), - ( - 'struct_type_int', - computation_types.StructType( - [np.int32, np.int32], - ), - ), - ( - 'struct_type_float', - computation_types.StructType([np.complex128, np.float32, np.float64]), - ), - ( - 'federated_type', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ]) - def test_does_not_raise_sum_incompatible_error(self, type_spec): - try: - type_analysis.check_is_sum_compatible(type_spec) - except type_analysis.SumIncompatibleError: - self.fail('Raised `SumIncompatibleError` unexpectedly.') - - @parameterized.named_parameters([ - ('tensor_type_bool', computation_types.TensorType(np.bool_)), - ('tensor_type_string', computation_types.TensorType(np.str_)), - ( - 'partially_defined_shape', - computation_types.TensorType(np.int32, shape=[None]), - ), - ('tuple_type', computation_types.StructType([np.int32, np.bool_])), - ('sequence_type', computation_types.SequenceType(np.int32)), - ('placement_type', computation_types.PlacementType()), - ('function_type', computation_types.FunctionType(np.int32, np.int32)), - ('abstract_type', computation_types.AbstractType('T')), - ]) - def test_raises_sum_incompatible_error(self, type_spec): - with self.assertRaises(type_analysis.SumIncompatibleError): - type_analysis.check_is_sum_compatible(type_spec) - - -class IsAverageCompatibleTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ('tensor_type_float32', computation_types.TensorType(np.float32)), - ('tensor_type_float64', computation_types.TensorType(np.float64)), - ( - 'tuple_type', - computation_types.StructType([('x', np.float32), ('y', np.float64)]), - ), - ( - 'federated_type', - computation_types.FederatedType(np.float32, placements.CLIENTS), - ), - ]) - def test_returns_true(self, type_spec): - self.assertTrue(type_analysis.is_average_compatible(type_spec)) - - @parameterized.named_parameters([ - ('tensor_type_int32', computation_types.TensorType(np.int32)), - ('tensor_type_int64', computation_types.TensorType(np.int64)), - ('sequence_type', computation_types.SequenceType(np.float32)), - ]) - def test_returns_false(self, type_spec): - self.assertFalse(type_analysis.is_average_compatible(type_spec)) - - -class IsMinMaxCompatibleTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ('tensor_type_int32', computation_types.TensorType(np.int32)), - ('tensor_type_int64', computation_types.TensorType(np.int64)), - ('tensor_type_float32', computation_types.TensorType(np.float32)), - ('tensor_type_float64', computation_types.TensorType(np.float64)), - ( - 'tuple_type', - computation_types.StructType([('x', np.float32), ('y', np.int64)]), - ), - ( - 'federated_type', - computation_types.FederatedType(np.float32, placements.CLIENTS), - ), - ]) - def test_returns_true(self, type_spec): - self.assertTrue(type_analysis.is_min_max_compatible(type_spec)) - - @parameterized.named_parameters([ - ('tensor_type_complex', computation_types.TensorType(np.complex128)), - ('sequence_type', computation_types.SequenceType(np.float32)), - ]) - def test_returns_false(self, type_spec): - self.assertFalse(type_analysis.is_min_max_compatible(type_spec)) - - -class CheckTypeTest(absltest.TestCase): - - def test_raises_type_error(self): - type_analysis.check_type(10, computation_types.TensorType(np.int32)) - self.assertRaises(TypeError, type_analysis.check_type, 10, np.bool_) - - -class CheckFederatedTypeTest(absltest.TestCase): - - def test_passes_or_raises_type_error(self): - type_spec = computation_types.FederatedType( - np.int32, placements.CLIENTS, False - ) - type_analysis.check_federated_type( - type_spec, - computation_types.TensorType(np.int32), - placements.CLIENTS, - False, - ) - type_analysis.check_federated_type( - type_spec, computation_types.TensorType(np.int32), None, None - ) - type_analysis.check_federated_type( - type_spec, None, placements.CLIENTS, None - ) - type_analysis.check_federated_type(type_spec, None, None, False) - self.assertRaises( - TypeError, - type_analysis.check_federated_type, - type_spec, - np.bool_, - None, - None, - ) - self.assertRaises( - TypeError, - type_analysis.check_federated_type, - type_spec, - None, - placements.SERVER, - None, - ) - self.assertRaises( - TypeError, - type_analysis.check_federated_type, - type_spec, - None, - None, - True, - ) - - -class IsStructureOfFloatsTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('empty_struct', computation_types.StructType([])), - ('float', computation_types.TensorType(np.float32)), - ('floats', computation_types.StructType([np.float32, np.float32])), - ( - 'nested_struct', - computation_types.StructType([ - computation_types.TensorType(np.float32), - computation_types.StructType([np.float32, np.float32]), - ]), - ), - ( - 'federated_float_at_clients', - computation_types.FederatedType(np.float32, placements.CLIENTS), - ), - ) - def test_returns_true(self, type_spec): - self.assertTrue(type_analysis.is_structure_of_floats(type_spec)) - - @parameterized.named_parameters( - ('bool', computation_types.TensorType(np.bool_)), - ('int', computation_types.TensorType(np.int32)), - ('string', computation_types.TensorType(np.str_)), - ('float_and_bool', computation_types.StructType([np.float32, np.bool_])), - ( - 'nested_struct', - computation_types.StructType([ - computation_types.TensorType(np.float32), - computation_types.StructType([np.bool_, np.bool_]), - ]), - ), - ('sequence_of_floats', computation_types.SequenceType(np.float32)), - ('placement', computation_types.PlacementType()), - ('function', computation_types.FunctionType(np.float32, np.float32)), - ('abstract', computation_types.AbstractType('T')), - ) - def test_returns_false(self, type_spec): - self.assertFalse(type_analysis.is_structure_of_floats(type_spec)) - - -class IsStructureOfIntegersTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('empty_struct', computation_types.StructType([])), - ('int', computation_types.TensorType(np.int32)), - ('ints', computation_types.StructType([np.int32, np.int32])), - ( - 'nested_struct', - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.StructType([np.int32, np.int32]), - ]), - ), - ( - 'federated_int_at_clients', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ) - def test_returns_true(self, type_spec): - self.assertTrue(type_analysis.is_structure_of_integers(type_spec)) - - @parameterized.named_parameters( - ('bool', computation_types.TensorType(np.bool_)), - ('float', computation_types.TensorType(np.float32)), - ('string', computation_types.TensorType(np.str_)), - ('int_and_bool', computation_types.StructType([np.int32, np.bool_])), - ( - 'nested_struct', - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.StructType([np.bool_, np.bool_]), - ]), - ), - ('sequence_of_ints', computation_types.SequenceType(np.int32)), - ('placement', computation_types.PlacementType()), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('abstract', computation_types.AbstractType('T')), - ) - def test_returns_false(self, type_spec): - self.assertFalse(type_analysis.is_structure_of_integers(type_spec)) - - -class IsSingleIntegerOrMatchesStructure(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'single int', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ( - 'struct', - computation_types.StructType([np.int32, np.int32]), - computation_types.StructType([np.int32, np.int32]), - ), - ( - 'struct with named fields', - computation_types.StructType([('x', np.int32)]), - computation_types.StructType([('x', np.int32)]), - ), - ( - 'single int for complex tensor', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32, [5, 97, 204]), - ), - ( - 'different kinds of ints', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int8), - ), - ( - 'single int_for_struct', - computation_types.TensorType(np.int32), - computation_types.StructType([np.int32, np.int32]), - ), - ) - def test_returns_true(self, type_sig, shape_type): - self.assertTrue( - type_analysis.is_single_integer_or_matches_structure( - type_sig, shape_type - ) - ) - - @parameterized.named_parameters( - ( - 'miscounted struct', - computation_types.StructType([np.int32, np.int32, np.int32]), - computation_types.StructType([np.int32, np.int32]), - ), - ( - 'miscounted struct 2', - computation_types.StructType([np.int32, np.int32]), - computation_types.StructType([np.int32, np.int32, np.int32]), - ), - ( - 'misnamed struct', - computation_types.StructType([('x', np.int32)]), - computation_types.StructType([('y', np.int32)]), - ), - ) - def test_returns_false(self, type_sig, shape_type): - self.assertFalse( - type_analysis.is_single_integer_or_matches_structure( - type_sig, shape_type - ) - ) - - -class IsAnonTupleWithPyContainerTest(absltest.TestCase): - - def test_returns_true(self): - value = structure.Struct([('a', 0.0)]) - type_spec = computation_types.StructWithPythonType( - [('a', np.float32)], dict - ) - self.assertTrue(type_analysis.is_struct_with_py_container(value, type_spec)) - - def test_returns_false_with_none_value(self): - value = None - type_spec = computation_types.StructWithPythonType( - [('a', np.float32)], dict - ) - self.assertFalse( - type_analysis.is_struct_with_py_container(value, type_spec) - ) - - def test_returns_false_with_named_tuple_type_spec(self): - value = structure.Struct([('a', 0.0)]) - type_spec = computation_types.StructType([('a', np.float32)]) - self.assertFalse( - type_analysis.is_struct_with_py_container(value, type_spec) - ) - - -class CheckConcreteInstanceOf(absltest.TestCase): - - def test_raises_with_int_first_argument(self): - with self.assertRaises(TypeError): - type_analysis.check_concrete_instance_of( - 1, computation_types.TensorType(np.int32) - ) - - def test_raises_with_int_second_argument(self): - with self.assertRaises(TypeError): - type_analysis.check_concrete_instance_of( - computation_types.TensorType(np.int32), 1 - ) - - def test_raises_different_structures(self): - with self.assertRaises(type_analysis.MismatchedStructureError): - type_analysis.check_concrete_instance_of( - computation_types.TensorType(np.int32), - computation_types.StructType([np.int32]), - ) - - def test_raises_with_abstract_type_as_first_arg(self): - t1 = computation_types.AbstractType('T1') - t2 = computation_types.TensorType(np.int32) - with self.assertRaises(type_analysis.NotConcreteTypeError): - type_analysis.check_concrete_instance_of(t1, t2) - - def test_with_single_abstract_type_and_tensor_type(self): - t1 = computation_types.AbstractType('T1') - t2 = computation_types.TensorType(np.int32) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_raises_with_abstract_type_in_first_and_second_argument(self): - t1 = computation_types.AbstractType('T1') - t2 = computation_types.AbstractType('T2') - with self.assertRaises(type_analysis.NotConcreteTypeError): - type_analysis.check_concrete_instance_of(t2, t1) - - def func_with_param(self, param_type): - return computation_types.FunctionType( - param_type, computation_types.StructType([]) - ) - - def test_with_single_abstract_type_and_tuple_type(self): - t1 = self.func_with_param(computation_types.AbstractType('T1')) - t2 = self.func_with_param(computation_types.StructType([np.int32])) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_raises_with_conflicting_names(self): - t1 = computation_types.StructType([np.int32] * 2) - t2 = computation_types.StructType([('a', np.int32), ('b', np.int32)]) - with self.assertRaises(type_analysis.MismatchedStructureError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_raises_with_different_lengths(self): - t1 = computation_types.StructType([np.int32] * 2) - t2 = computation_types.StructType([np.int32]) - with self.assertRaises(type_analysis.MismatchedStructureError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_succeeds_under_tuple(self): - t1 = self.func_with_param( - computation_types.StructType([computation_types.AbstractType('T1')] * 2) - ) - t2 = self.func_with_param( - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ]) - ) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_fails_under_tuple_conflicting_concrete_types(self): - t1 = self.func_with_param( - computation_types.StructType([computation_types.AbstractType('T1')] * 2) - ) - t2 = self.func_with_param( - computation_types.StructType([ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.float32), - ]) - ) - with self.assertRaises(type_analysis.MismatchedConcreteTypesError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_succeeds_abstract_type_under_sequence_type(self): - t1 = self.func_with_param( - computation_types.SequenceType(computation_types.AbstractType('T')) - ) - t2 = self.func_with_param(computation_types.SequenceType(np.int32)) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_fails_conflicting_concrete_types_under_sequence(self): - t1 = self.func_with_param( - computation_types.SequenceType( - [computation_types.AbstractType('T')] * 2 - ) - ) - t2 = self.func_with_param( - computation_types.SequenceType([np.int32, np.float32]) - ) - with self.assertRaises(type_analysis.MismatchedConcreteTypesError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_succeeds_single_function_type(self): - t1 = computation_types.FunctionType( - *[computation_types.AbstractType('T')] * 2 - ) - t2 = computation_types.FunctionType(np.int32, np.int32) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_succeeds_function_different_parameter_and_return_types(self): - t1 = computation_types.FunctionType( - computation_types.StructType([ - computation_types.AbstractType('U'), - computation_types.AbstractType('T'), - ]), - computation_types.AbstractType('T'), - ) - t2 = computation_types.FunctionType( - computation_types.StructType([np.int32, np.float32]), np.float32 - ) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_fails_conflicting_binding_in_parameter_and_result(self): - t1 = computation_types.FunctionType( - computation_types.AbstractType('T'), computation_types.AbstractType('T') - ) - t2 = computation_types.FunctionType(np.int32, np.float32) - with self.assertRaises(type_analysis.UnassignableConcreteTypesError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_abstract_federated_types_succeeds(self): - t1 = self.func_with_param( - computation_types.FederatedType( - [computation_types.AbstractType('T1')] * 2, - placements.CLIENTS, - all_equal=True, - ) - ) - t2 = self.func_with_param( - computation_types.FederatedType( - [np.int32] * 2, placements.CLIENTS, all_equal=True - ) - ) - type_analysis.check_concrete_instance_of(t2, t1) - - def test_abstract_fails_on_different_federated_placements(self): - t1 = self.func_with_param( - computation_types.FederatedType( - [computation_types.AbstractType('T1')] * 2, - placements.CLIENTS, - all_equal=True, - ) - ) - t2 = self.func_with_param( - computation_types.FederatedType( - [np.int32] * 2, placements.SERVER, all_equal=True - ) - ) - with self.assertRaises(type_analysis.MismatchedStructureError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_abstract_can_be_concretized_fails_on_different_placements(self): - t1 = self.func_with_param( - computation_types.FederatedType( - [computation_types.AbstractType('T1')] * 2, - placements.CLIENTS, - all_equal=True, - ) - ) - t2 = self.func_with_param( - computation_types.FederatedType( - [np.int32] * 2, placements.SERVER, all_equal=True - ) - ) - with self.assertRaises(type_analysis.MismatchedStructureError): - type_analysis.check_concrete_instance_of(t2, t1) - - def test_abstract_parameters_contravariant(self): - struct = lambda name: computation_types.StructType([(name, np.int32)]) - unnamed = struct(None) - concrete = computation_types.FunctionType( - computation_types.StructType( - [unnamed, computation_types.FunctionType(struct('bar'), unnamed)] - ), - struct('foo'), - ) - abstract = computation_types.AbstractType('A') - generic = computation_types.FunctionType( - computation_types.StructType( - [abstract, computation_types.FunctionType(abstract, abstract)] - ), - abstract, - ) - type_analysis.check_concrete_instance_of(concrete, generic) - - -class IsBinaryOpWithUpcastCompatibleTest(absltest.TestCase): - - def test_fails_on_none(self): - self.assertFalse( - type_analysis.is_binary_op_with_upcast_compatible_pair(None, None) - ) - - def test_passes_empty_tuples(self): - self.assertTrue( - type_analysis.is_binary_op_with_upcast_compatible_pair( - computation_types.StructType([]), computation_types.StructType([]) - ) - ) - - def test_fails_scalars_different_dtypes(self): - self.assertFalse( - type_analysis.is_binary_op_with_upcast_compatible_pair( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.float32), - ) - ) - - def test_passes_named_tuple_and_compatible_scalar(self): - self.assertTrue( - type_analysis.is_binary_op_with_upcast_compatible_pair( - computation_types.StructType( - [('a', computation_types.TensorType(np.int32, [2, 2]))] - ), - computation_types.TensorType(np.int32), - ) - ) - - def test_fails_named_tuple_and_incompatible_scalar(self): - self.assertFalse( - type_analysis.is_binary_op_with_upcast_compatible_pair( - computation_types.StructType( - [('a', computation_types.TensorType(np.int32, [2, 2]))] - ), - computation_types.TensorType(np.float32), - ) - ) - - def test_fails_compatible_scalar_and_named_tuple(self): - self.assertFalse( - type_analysis.is_binary_op_with_upcast_compatible_pair( - computation_types.TensorType(np.float32), - computation_types.StructType( - [('a', computation_types.TensorType(np.int32, [2, 2]))] - ), - ) - ) - - def test_fails_named_tuple_type_and_non_scalar_tensor(self): - self.assertFalse( - type_analysis.is_binary_op_with_upcast_compatible_pair( - computation_types.StructType( - [('a', computation_types.TensorType(np.int32, [2, 2]))] - ), - computation_types.TensorType(np.int32, [2]), - ) - ) - - -class TestCheckValidFederatedWeightedMeanArgumentTupleTypeTest( - absltest.TestCase -): - - def test_raises_type_error(self): - type_analysis.check_valid_federated_weighted_mean_argument_tuple_type( - computation_types.StructType( - [computation_types.FederatedType(np.float32, placements.CLIENTS)] - * 2 - ) - ) - with self.assertRaises(TypeError): - type_analysis.check_valid_federated_weighted_mean_argument_tuple_type( - computation_types.StructType( - [computation_types.FederatedType(np.int32, placements.CLIENTS)] - * 2 - ) - ) - - -class CountTensorsInTypeTest(absltest.TestCase): - - def test_raises_non_type(self): - with self.assertRaises(TypeError): - type_analysis.count_tensors_in_type(0) - - def test_counts_all_tensors_no_filter(self): - struct_type = computation_types.StructType([ - ('a', computation_types.TensorType(np.int32, shape=[2, 2])), - ('b', computation_types.TensorType(np.int32, shape=[2, 1])), - ]) - - tensors_and_param_count = type_analysis.count_tensors_in_type(struct_type) - - expected_tensors_and_param_count = collections.OrderedDict( - num_tensors=2, parameters=6, num_unspecified_tensors=0 - ) - self.assertEqual(tensors_and_param_count, expected_tensors_and_param_count) - - def test_skips_unspecified_params(self): - struct_type = computation_types.StructType([ - ('a', computation_types.TensorType(np.int32, shape=[2, 2])), - ('b', computation_types.TensorType(np.int32, shape=[None, 1])), - ]) - - tensors_and_param_count = type_analysis.count_tensors_in_type(struct_type) - - expected_tensors_and_param_count = collections.OrderedDict( - num_tensors=2, parameters=4, num_unspecified_tensors=1 - ) - self.assertEqual(tensors_and_param_count, expected_tensors_and_param_count) - - def test_tensor_filter_only_counts_matching_tensors(self): - struct_type = computation_types.StructType([ - ('a', computation_types.TensorType(np.float32, shape=[2, 2])), - ('b', computation_types.TensorType(np.int32, shape=[2, 1])), - ]) - tensor_filter = lambda tensor_type: tensor_type.dtype == np.float32 - - tensors_and_param_count = type_analysis.count_tensors_in_type( - struct_type, tensor_filter - ) - - expected_tensors_and_param_count = collections.OrderedDict( - num_tensors=1, parameters=4, num_unspecified_tensors=0 - ) - self.assertEqual(tensors_and_param_count, expected_tensors_and_param_count) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/type_conversions.py b/tensorflow_federated/python/core/impl/types/type_conversions.py deleted file mode 100644 index ebd2df0cc3..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_conversions.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# limitations under the License. -"""Utilities for type conversion, type checking, type inference, etc.""" - -import collections -from collections.abc import Hashable, Mapping, Sequence -import functools -import typing -from typing import Optional, Union - -import attrs -import numpy as np -import tree - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import dtype_utils -from tensorflow_federated.python.core.impl.types import typed_object - - -def infer_type(arg: object) -> Optional[computation_types.Type]: - """Infers the TFF type of the argument (a `computation_types.Type` instance). - - Warning: This function is only partially implemented. - - The kinds of arguments that are currently correctly recognized: - * tensors, variables, and data sets - * things that are convertible to tensors (including `numpy` arrays, builtin - types, as well as `list`s and `tuple`s of any of the above, etc.) - * nested lists, `tuple`s, `namedtuple`s, anonymous `tuple`s, `dict`, - `OrderedDict`s, `attrs` classes, and `tff.TypedObject`s - - Args: - arg: The argument, the TFF type of which to infer. - - Returns: - Either an instance of `computation_types.Type`, or `None` if the argument is - `None`. - """ - # TODO: b/224484886 - Downcasting to all handled types. - arg = typing.cast( - Union[ - None, - typed_object.TypedObject, - structure.Struct, - py_typecheck.SupportsNamedTuple, - Mapping[Hashable, object], - tuple[object, ...], - list[object], - ], - arg, - ) - if arg is None: - return None - elif isinstance(arg, typed_object.TypedObject): - return arg.type_signature - elif isinstance(arg, structure.Struct): - return computation_types.StructType([ - (k, infer_type(v)) if k else infer_type(v) - for k, v in structure.iter_elements(arg) - ]) - elif attrs.has(type(arg)): - items = attrs.asdict(arg, recurse=False).items() - return computation_types.StructWithPythonType( - [(k, infer_type(v)) for k, v in items], type(arg) - ) - elif isinstance(arg, py_typecheck.SupportsNamedTuple): - items = arg._asdict().items() - return computation_types.StructWithPythonType( - [(k, infer_type(v)) for k, v in items], type(arg) - ) - elif isinstance(arg, Mapping): - items = arg.items() - return computation_types.StructWithPythonType( - [(k, infer_type(v)) for k, v in items], type(arg) - ) - elif isinstance(arg, (tuple, list)): - elements = [] - all_elements_named = True - for element in arg: - all_elements_named &= py_typecheck.is_name_value_pair( - element, name_type=str - ) - elements.append(infer_type(element)) - # If this is a tuple of (name, value) pairs, the caller most likely intended - # this to be a StructType, so we avoid storing the Python container. - if elements and all_elements_named: - return computation_types.StructType(elements) - else: - return computation_types.StructWithPythonType(elements, type(arg)) - elif isinstance(arg, (np.ndarray, np.generic)): - return computation_types.TensorType(arg.dtype, arg.shape) - elif isinstance(arg, (bool, int, float, complex, str, bytes)): - dtype = dtype_utils.infer_dtype(arg) - return computation_types.TensorType(dtype) - else: - raise NotImplementedError(f'Unexpected type found: {type(arg)}.') - - -def to_structure_with_type( - obj: object, type_spec: computation_types.Type -) -> object: - """Converts the containers in `obj` to those defined by `type_spec`. - - Note: This function does not convert any leaf in the structure. - - For example: - - >>> obj = 1 - >>> type_spec = tff.TensorType(np.int32) - >>> tff.types.to_structure_with_type(obj, type_spec) - 1 - - >>> obj = [1, 2, 3] - >>> type_spec = tff.StructType([np.int32] * 3) - >>> tff.types.to_structure_with_type(obj, type_spec) - [1, 2, 3] - - >>> obj = [1, 2, 3] - >>> type_spec = tff.StructType([ - >>> ('a', np.int32), - >>> ('b', np.int32), - >>> ('c', np.int32), - >>> ]) - >>> tff.types.to_structure_with_type(obj, type_spec) - {'a': 1, 'b': 2, 'c': 3} - - Args: - obj: A Python value. - type_spec: The `tff.Type` to use convert `obj`. - - Returns: - A Python value equivalent to `obj` with a structure matching `type_spec`. - - Raises: - ValueError: If `obj` and `type_spec` do not match or a container does not - have either all named or unnamed elements. - """ - if not tree.is_nested(obj): - return obj - - def _get_item( - type_spec: computation_types.Type, key: Union[str, int] - ) -> Union[computation_types.FederatedType, computation_types.StructType]: - if isinstance(type_spec, computation_types.FederatedType): - type_spec = type_spec.member - if not isinstance(type_spec, computation_types.StructType): - raise ValueError( - 'Expected `type_spec` to be a `tff.StructType`, found' - f' {type(type_spec)}.' - ) - - return type_spec[key] - - def _to_structure(path: tuple[Union[str, int], ...], obj: object) -> object: - if tree.is_nested(obj): - container_type = functools.reduce(_get_item, path, type_spec) - if isinstance(container_type, computation_types.FederatedType): - container_type = container_type.member - if not isinstance(container_type, computation_types.StructType): - raise ValueError( - 'Expected `container_type` to be a `tff.StructType`, found' - f' {type(container_type)}.' - ) - - container_cls = container_type.python_container - if container_cls is None: - names = [name is not None for name, _ in container_type.items()] - if any(names): - if not all(names): - raise ValueError( - 'Expected `container_type` to have either all named or unnamed' - f' elements, found {container_type}.' - ) - container_cls = dict - else: - container_cls = list - - if isinstance(obj, py_typecheck.SupportsNamedTuple): - elements = obj._asdict().values() - elif isinstance(obj, Mapping): - elements = obj.values() - elif isinstance(obj, Sequence): - elements = obj - else: - raise ValueError( - 'Expected `obj` to be a `NamedTuple`, `Mapping`, or `Sequence`,' - f' found {type(obj)}.' - ) - - if isinstance(container_cls, py_typecheck.SupportsNamedTuple): - names = [name for name, _ in container_type.items()] - return container_cls(**dict(zip(names, elements))) - elif issubclass(container_cls, Mapping): - names = [name for name, _ in container_type.items()] - return container_cls(zip(names, elements)) # pylint: disable=too-many-function-args - elif issubclass(container_cls, Sequence): - return container_cls(elements) # pylint: disable=too-many-function-args - else: - raise ValueError( - 'Expected `container_cls` to be a `NamedTuple`, `Mapping`, or' - f' `Sequence`, found {container_cls}.' - ) - else: - return None - - return tree.traverse_with_path(_to_structure, obj, top_down=False) - - -def _is_container_type_without_names(container_type: type[object]) -> bool: - """Returns whether `container_type`'s elements are unnamed.""" - return issubclass(container_type, (list, tuple)) and not isinstance( - container_type, py_typecheck.SupportsNamedTuple - ) - - -def _is_container_type_with_names(container_type: type[object]) -> bool: - """Returns whether `container_type`'s elements are named.""" - return ( - isinstance(container_type, py_typecheck.SupportsNamedTuple) - or attrs.has(container_type) - or issubclass(container_type, dict) - ) - - -def type_to_py_container(value, type_spec: computation_types.Type): - """Recursively convert `structure.Struct`s to Python containers. - - This is in some sense the inverse operation to - `structure.from_container`. - - This function assumes some unique behavior with regards to `tff.SequenceType`. - If the `value` is a list, it may yield other `tff.StructTypes` as well as - Python types. Otherwise, it may only yield Python types. - - Args: - value: A structure of anonymous tuples of values corresponding to - `type_spec`. - type_spec: The `tff.Type` to which value should conform, possibly including - `computation_types.StructWithPythonType`. - - Returns: - The input value, with containers converted to appropriate Python - containers as specified by the `type_spec`. - - Raises: - ValueError: If the conversion is not possible due to a mix of named - and unnamed values, or if `value` contains names that are mismatched or - not present in the corresponding index of `type_spec`. - """ - if isinstance(type_spec, computation_types.FederatedType): - if type_spec.all_equal: - structure_type_spec = type_spec.member - else: - if not isinstance(value, list): - raise TypeError( - 'Unexpected Python type for non-all-equal TFF type ' - f'{type_spec}: expected `list`, found `{type(value)}`.' - ) - return [ - type_to_py_container(element, type_spec.member) for element in value - ] - else: - structure_type_spec = type_spec - - if isinstance(structure_type_spec, computation_types.SequenceType): - element_type = structure_type_spec.element - if isinstance(value, list): - return [type_to_py_container(element, element_type) for element in value] - else: - # Assume that the type of value does not understand `Struct` and that the - # value must yield Python containers. - return value - - if not isinstance(structure_type_spec, computation_types.StructType): - return value - - if not isinstance(value, structure.Struct): - # NOTE: When encountering non-`structure.Struct`s, we assume that - # this means that we're attempting to re-convert a value that - # already has the proper containers, and we short-circuit to - # avoid re-converting. This is a possibly dangerous assumption. - return value - - container_type = structure_type_spec.python_container - - # Ensure that names are only added, not mismatched or removed - names_from_value = structure.name_list_with_nones(value) - names_from_type_spec = structure.name_list_with_nones(structure_type_spec) - for value_name, type_name in zip(names_from_value, names_from_type_spec): - if value_name is not None: - if value_name != type_name: - raise ValueError( - f'Cannot convert value with field name `{value_name}` into a ' - f'type with field name `{type_name}`.' - ) - - num_named_elements = len(dir(structure_type_spec)) - num_unnamed_elements = len(structure_type_spec) - num_named_elements - if num_named_elements > 0 and num_unnamed_elements > 0: - raise ValueError( - f'Cannot represent value {value} with a Python container because it ' - 'contains a mix of named and unnamed elements.\n\nNote: this was ' - 'previously allowed when using the `tff.structure.Struct` container. ' - 'This support has been removed: please change to use structures with ' - 'either all-named or all-unnamed fields.' - ) - if container_type is None: - if num_named_elements: - container_type = collections.OrderedDict - else: - container_type = tuple - - # Avoid projecting the `structure.StructType`d TFF value into a Python - # container that is not supported. - if num_named_elements > 0 and _is_container_type_without_names( - container_type - ): - raise ValueError( - 'Cannot represent value {} with named elements ' - "using container type {} which does not support names. In TFF's " - 'typesystem, this corresponds to an implicit downcast'.format( - value, container_type - ) - ) - if _is_container_type_with_names(container_type) and len( - dir(structure_type_spec) - ) != len(value): - # If the type specifies the names, we have all the information we need. - # Otherwise we must raise here. - raise ValueError( - 'When packaging as a Python value which requires names, ' - 'the TFF type spec must have all names specified. Found ' - '{} names in type spec {} of length {}, with requested' - 'python type {}.'.format( - len(dir(structure_type_spec)), - structure_type_spec, - len(value), - container_type, - ) - ) - - elements = [] - for index, (elem_name, elem_type) in enumerate(structure_type_spec.items()): - element = type_to_py_container(value[index], elem_type) - - if elem_name is None: - elements.append(element) - else: - elements.append((elem_name, element)) - - if ( - isinstance(container_type, py_typecheck.SupportsNamedTuple) - or attrs.has(container_type) - ): - # The namedtuple and attr.s class constructors cannot interpret a list of - # (name, value) tuples; instead call constructor using kwargs. Note that - # these classes already define an order of names internally, so order does - # not matter. - return container_type(**dict(elements)) - else: - # E.g., tuple and list when elements only has values, but also `dict`, - # `collections.OrderedDict`, or `structure.Struct` when - # elements has (name, value) tuples. - return container_type(elements) # pytype: disable=wrong-arg-count - - -def type_to_non_all_equal(type_spec): - """Constructs a non-`all_equal` version of the federated type `type_spec`. - - Args: - type_spec: An instance of `tff.FederatedType`. - - Returns: - A federated type with the same member and placement, but `all_equal=False`. - """ - py_typecheck.check_type(type_spec, computation_types.FederatedType) - return computation_types.FederatedType( - type_spec.member, type_spec.placement, all_equal=False - ) diff --git a/tensorflow_federated/python/core/impl/types/type_conversions_test.py b/tensorflow_federated/python/core/impl/types/type_conversions_test.py deleted file mode 100644 index 784ed1fcc6..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_conversions_test.py +++ /dev/null @@ -1,596 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -from collections.abc import Mapping -from typing import NamedTuple - -from absl.testing import absltest -from absl.testing import parameterized -import attrs -import numpy as np - -from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.core.impl.types import typed_object - - -class _TestTypedObject(typed_object.TypedObject): - - def __init__(self, type_signature: computation_types.Type): - self._type_signature = type_signature - - @property - def type_signature(self) -> computation_types.Type: - return self._type_signature - - -class _TestNamedTuple(NamedTuple): - a: int - b: int - c: int - - -class InferTypeTest(parameterized.TestCase): - - def test_with_none(self): - self.assertIsNone(type_conversions.infer_type(None)) - - def test_with_typed_object(self): - obj = _TestTypedObject(computation_types.TensorType(np.bool_)) - - whimsy_type = type_conversions.infer_type(obj) - self.assertEqual(whimsy_type.compact_representation(), 'bool') - - def test_with_scalar_int_tensor(self): - self.assertEqual(str(type_conversions.infer_type(np.int32(1))), 'int32') - self.assertEqual(str(type_conversions.infer_type(np.int64(1))), 'int64') - self.assertEqual(str(type_conversions.infer_type(np.int64(-1))), 'int64') - - def test_with_scalar_bool_tensor(self): - self.assertEqual(str(type_conversions.infer_type(np.bool_(False))), 'bool') - - def test_with_int_array_tensor(self): - self.assertEqual( - str(type_conversions.infer_type(np.array([10, 20], dtype=np.int32))), - 'int32[2]', - ) - self.assertEqual( - str( - type_conversions.infer_type( - np.array([0, 2**40, -(2**60), 0], dtype=np.int64) - ) - ), - 'int64[4]', - ) - - def test_with_int(self): - self.assertEqual(str(type_conversions.infer_type(10)), 'int32') - - def test_with_float(self): - self.assertEqual(str(type_conversions.infer_type(0.5)), 'float32') - - def test_with_bool(self): - self.assertEqual(str(type_conversions.infer_type(True)), 'bool') - - def test_with_string(self): - self.assertEqual(str(type_conversions.infer_type('abc')), 'str') - - def test_with_np_int32(self): - self.assertEqual(str(type_conversions.infer_type(np.int32(10))), 'int32') - - def test_with_np_int64(self): - self.assertEqual(str(type_conversions.infer_type(np.int64(10))), 'int64') - - def test_with_np_float32(self): - self.assertEqual( - str(type_conversions.infer_type(np.float32(10))), 'float32' - ) - - def test_with_np_float64(self): - self.assertEqual( - str(type_conversions.infer_type(np.float64(10))), 'float64' - ) - - def test_with_np_bool(self): - self.assertEqual(str(type_conversions.infer_type(np.bool_(True))), 'bool') - - def test_with_unicode_string(self): - self.assertEqual(str(type_conversions.infer_type('abc')), 'str') - - def test_with_numpy_int_array(self): - self.assertEqual( - str(type_conversions.infer_type(np.array([10, 20]))), 'int64[2]' - ) - - def test_with_numpy_nested_int_array(self): - self.assertEqual( - str(type_conversions.infer_type(np.array([[10], [20]]))), 'int64[2,1]' - ) - - def test_with_numpy_float64_scalar(self): - self.assertEqual(str(type_conversions.infer_type(np.float64(1))), 'float64') - - def test_with_int_list(self): - t = type_conversions.infer_type([1, 2, 3]) - self.assertEqual(str(t), '') - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, list) - - def test_with_nested_float_list(self): - t = type_conversions.infer_type([[0.1], [0.2], [0.3]]) - self.assertEqual(str(t), '<,,>') - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, list) - - def test_with_structure(self): - t = type_conversions.infer_type( - structure.Struct([ - ('a', 10), - (None, False), - ]) - ) - self.assertEqual(str(t), '') - self.assertIsInstance(t, computation_types.StructType) - self.assertNotIsInstance(t, computation_types.StructWithPythonType) - - def test_with_nested_structure(self): - t = type_conversions.infer_type( - structure.Struct([ - ('a', 10), - ( - None, - structure.Struct([ - (None, True), - (None, 0.5), - ]), - ), - ]) - ) - self.assertEqual(str(t), '>') - self.assertIsInstance(t, computation_types.StructType) - self.assertNotIsInstance(t, computation_types.StructWithPythonType) - - def test_with_namedtuple(self): - test_named_tuple = collections.namedtuple('TestNamedTuple', 'y x') - t = type_conversions.infer_type(test_named_tuple(1, True)) - self.assertEqual(str(t), '') - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, test_named_tuple) - - def test_with_dict(self): - v1 = { - 'a': 1, - 'b': 2.0, - } - inferred_type = type_conversions.infer_type(v1) - self.assertEqual(str(inferred_type), '') - self.assertIsInstance(inferred_type, computation_types.StructWithPythonType) - self.assertIs(inferred_type.python_container, dict) - - def test_with_ordered_dict(self): - t = type_conversions.infer_type( - collections.OrderedDict([('b', 2.0), ('a', 1)]) - ) - self.assertEqual(str(t), '') - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, collections.OrderedDict) - - def test_with_nested_attrs_class(self): - - @attrs.define - class TestAttrClass: - a: int - b: Mapping[str, object] - - t = type_conversions.infer_type( - TestAttrClass(a=0, b=collections.OrderedDict(x=True, y=0.0)) - ) - self.assertEqual(str(t), '>') - self.assertIsInstance(t, computation_types.StructWithPythonType) - self.assertIs(t.python_container, TestAttrClass) - self.assertIs(t.b.python_container, collections.OrderedDict) - - def test_with_empty_tuple(self): - t = type_conversions.infer_type(()) - self.assertEqual(t, computation_types.StructWithPythonType([], tuple)) - - -class ToStructureWithTypeTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('value', 1, computation_types.TensorType(np.int32), 1), - ( - 'list_to_list', - [1, 2, 3], - computation_types.StructType([np.int32] * 3), - [1, 2, 3], - ), - ( - 'list_to_dict', - [1, 2, 3], - computation_types.StructType([ - ('a', np.int32), - ('b', np.int32), - ('c', np.int32), - ]), - {'a': 1, 'b': 2, 'c': 3}, - ), - ( - 'list_to_named_tuple', - [1, 2, 3], - computation_types.StructWithPythonType( - [ - ('a', np.int32), - ('b', np.int32), - ('c', np.int32), - ], - container_type=_TestNamedTuple, - ), - _TestNamedTuple(1, 2, 3), - ), - ( - 'dict_to_list', - {'a': 1, 'b': 2, 'c': 3}, - computation_types.StructType([np.int32] * 3), - [1, 2, 3], - ), - ( - 'dict_to_dict', - {'a': 1, 'b': 2, 'c': 3}, - computation_types.StructType([ - ('a', np.int32), - ('b', np.int32), - ('c', np.int32), - ]), - {'a': 1, 'b': 2, 'c': 3}, - ), - ( - 'dict_to_named_tuple', - {'a': 1, 'b': 2, 'c': 3}, - computation_types.StructWithPythonType( - [ - ('a', np.int32), - ('b', np.int32), - ('c', np.int32), - ], - container_type=_TestNamedTuple, - ), - _TestNamedTuple(1, 2, 3), - ), - ( - 'named_tuple_to_list', - _TestNamedTuple(1, 2, 3), - computation_types.StructType([np.int32] * 3), - [1, 2, 3], - ), - ( - 'named_tuple_to_dict', - _TestNamedTuple(1, 2, 3), - computation_types.StructType([ - ('a', np.int32), - ('b', np.int32), - ('c', np.int32), - ]), - {'a': 1, 'b': 2, 'c': 3}, - ), - ( - 'named_tuple_to_named_tuple', - _TestNamedTuple(1, 2, 3), - computation_types.StructWithPythonType( - [ - ('a', np.int32), - ('b', np.int32), - ('c', np.int32), - ], - container_type=_TestNamedTuple, - ), - _TestNamedTuple(1, 2, 3), - ), - ( - 'federated_value', - 1, - computation_types.FederatedType(np.int32, placements.CLIENTS), - 1, - ), - ( - 'federated_value_in_structure', - [1, 2, 3], - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.CLIENTS), - ]), - [1, 2, 3], - ), - ( - 'federated_structure', - [1, 2, 3], - computation_types.FederatedType([np.int32] * 3, placements.CLIENTS), - [1, 2, 3], - ), - ( - 'federated_structure_nested', - [[1, 2], [3]], - computation_types.FederatedType( - [[np.int32] * 2, [np.int32]], placements.CLIENTS - ), - [[1, 2], [3]], - ), - ) - def test_returns_result(self, obj, type_spec, expected): - actual = type_conversions.to_structure_with_type(obj, type_spec) - self.assertEqual(actual, expected) - - @parameterized.named_parameters( - ( - 'wrong_type_spec', - [1, 2, 3], - computation_types.TensorType(np.int32), - ), - ( - 'wrong_type_spec_nested', - [[1, 2], [3]], - computation_types.StructType([np.int32] * 3), - ), - ( - 'partially_named', - [1, 2, 3], - computation_types.StructType([ - ('a', np.int32), - ('b', np.int32), - (None, np.int32), - ]), - ), - ) - def test_raises_value_error(self, obj, type_spec): - with self.assertRaises(ValueError): - type_conversions.to_structure_with_type(obj, type_spec) - - -class TypeToPyContainerTest(absltest.TestCase): - - def test_tuple_passthrough(self): - value = (1, 2.0) - result = type_conversions.type_to_py_container( - (1, 2.0), - computation_types.StructWithPythonType( - [np.int32, np.float32], container_type=list - ), - ) - self.assertEqual(result, value) - - def test_represents_unnamed_fields_as_tuple(self): - input_value = structure.Struct([(None, 1), (None, 2.0)]) - input_type = computation_types.StructType([np.int32, np.float32]) - self.assertEqual( - type_conversions.type_to_py_container(input_value, input_type), (1, 2.0) - ) - - def test_represents_named_fields_as_odict(self): - input_value = structure.Struct([('a', 1), ('b', 2.0)]) - input_type = computation_types.StructType( - [('a', np.int32), ('b', np.float32)] - ) - self.assertEqual( - type_conversions.type_to_py_container(input_value, input_type), - collections.OrderedDict(a=1, b=2.0), - ) - - def test_raises_on_mixed_named_unnamed(self): - input_value = structure.Struct([('a', 1), (None, 2.0)]) - input_type = computation_types.StructType( - [('a', np.int32), (None, np.float32)] - ) - with self.assertRaises(ValueError): - type_conversions.type_to_py_container(input_value, input_type) - - def test_anon_tuple_without_names_to_container_without_names(self): - anon_tuple = structure.Struct([(None, 1), (None, 2.0)]) - types = [np.int32, np.float32] - self.assertSequenceEqual( - type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, list) - ), - [1, 2.0], - ) - self.assertSequenceEqual( - type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, tuple) - ), - (1, 2.0), - ) - - def test_succeeds_with_federated_namedtupletype(self): - anon_tuple = structure.Struct([(None, 1), (None, 2.0)]) - types = [np.int32, np.float32] - self.assertSequenceEqual( - type_conversions.type_to_py_container( - anon_tuple, - computation_types.FederatedType( - computation_types.StructWithPythonType(types, list), - placements.SERVER, - ), - ), - [1, 2.0], - ) - self.assertSequenceEqual( - type_conversions.type_to_py_container( - anon_tuple, - computation_types.FederatedType( - computation_types.StructWithPythonType(types, tuple), - placements.SERVER, - ), - ), - (1, 2.0), - ) - - def test_client_placed_tuple(self): - value = [ - structure.Struct([(None, 1), (None, 2)]), - structure.Struct([(None, 3), (None, 4)]), - ] - type_spec = computation_types.FederatedType( - computation_types.StructWithPythonType( - [(None, np.int32), (None, np.int32)], tuple - ), - placements.CLIENTS, - ) - self.assertEqual( - [(1, 2), (3, 4)], - type_conversions.type_to_py_container(value, type_spec), - ) - - def test_anon_tuple_with_names_to_container_without_names_fails(self): - anon_tuple = structure.Struct([(None, 1), ('a', 2.0)]) - types = [np.int32, np.float32] - with self.assertRaisesRegex( - ValueError, 'Cannot convert value with field name' - ): - type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, tuple) - ) - anon_tuple = structure.Struct([('a', 1), ('b', 2.0)]) - with self.assertRaisesRegex( - ValueError, 'Cannot convert value with field name' - ): - type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, list) - ) - - def test_anon_tuple_with_names_to_container_with_names(self): - anon_tuple = structure.Struct([('a', 1), ('b', 2.0)]) - types = [('a', np.int32), ('b', np.float32)] - self.assertDictEqual( - type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, dict) - ), - {'a': 1, 'b': 2.0}, - ) - self.assertSequenceEqual( - type_conversions.type_to_py_container( - anon_tuple, - computation_types.StructWithPythonType( - types, collections.OrderedDict - ), - ), - collections.OrderedDict([('a', 1), ('b', 2.0)]), - ) - test_named_tuple = collections.namedtuple('TestNamedTuple', ['a', 'b']) - self.assertSequenceEqual( - type_conversions.type_to_py_container( - anon_tuple, - computation_types.StructWithPythonType(types, test_named_tuple), - ), - test_named_tuple(a=1, b=2.0), - ) - - @attrs.define - class TestFoo: - a: int - b: float - - self.assertEqual( - type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, TestFoo) - ), - TestFoo(a=1, b=2.0), - ) - - def test_anon_tuple_without_names_promoted_to_container_with_names(self): - anon_tuple = structure.Struct([(None, 1), (None, 2.0)]) - types = [('a', np.int32), ('b', np.float32)] - dict_converted_value = type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, dict) - ) - odict_converted_value = type_conversions.type_to_py_container( - anon_tuple, - computation_types.StructWithPythonType(types, collections.OrderedDict), - ) - - test_named_tuple = collections.namedtuple('TestNamedTuple', ['a', 'b']) - named_tuple_converted_value = type_conversions.type_to_py_container( - anon_tuple, - computation_types.StructWithPythonType(types, test_named_tuple), - ) - - @attrs.define - class TestFoo: - a: object - b: object - - attr_converted_value = type_conversions.type_to_py_container( - anon_tuple, computation_types.StructWithPythonType(types, TestFoo) - ) - - self.assertIsInstance(dict_converted_value, dict) - self.assertIsInstance(odict_converted_value, collections.OrderedDict) - self.assertIsInstance(named_tuple_converted_value, test_named_tuple) - self.assertIsInstance(attr_converted_value, TestFoo) - - def test_nested_py_containers(self): - anon_tuple = structure.Struct([ - (None, 1), - (None, 2.0), - ( - None, - structure.Struct( - [('a', 3), ('b', structure.Struct([(None, 4), (None, 5)]))] - ), - ), - ]) - - dict_subtype = computation_types.StructWithPythonType( - [ - ('a', np.int32), - ( - 'b', - computation_types.StructWithPythonType( - [np.int32, np.int32], tuple - ), - ), - ], - dict, - ) - type_spec = computation_types.StructType( - [(None, np.int32), (None, np.float32), (None, dict_subtype)] - ) - - expected_nested_structure = (1, 2.0, collections.OrderedDict(a=3, b=(4, 5))) - self.assertEqual( - type_conversions.type_to_py_container(anon_tuple, type_spec), - expected_nested_structure, - ) - - -class TypeToNonAllEqualTest(absltest.TestCase): - - def test_with_bool(self): - for x in [True, False]: - self.assertEqual( - str( - type_conversions.type_to_non_all_equal( - computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=x - ) - ) - ), - '{int32}@CLIENTS', - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/type_factory.py b/tensorflow_federated/python/core/impl/types/type_factory.py deleted file mode 100644 index daba69efcc..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_factory.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# limitations under the License. -"""A library of construction functions for computation type structures.""" - -from tensorflow_federated.python.core.impl.types import computation_types - - -def reduction_op( - result_type_spec: computation_types.Type, - element_type_spec: computation_types.Type, -) -> computation_types.Type: - """Returns the type of a reduction operator of the form `( -> U)`. - - Args: - result_type_spec: A `computation_types.Type`, the result of reduction (`U`). - element_type_spec: A `computation_types.Type`, the type of elements to be - reduced (`T`). - - Returns: - The type of the corresponding reduction operator (`( -> U)`). - """ - return computation_types.FunctionType( - computation_types.StructType([result_type_spec, element_type_spec]), - result_type_spec, - ) - - -def unary_op(type_spec: computation_types.Type) -> computation_types.Type: - """Returns the type of an unary operator that operates on `type_spec`. - - Args: - type_spec: A `computation_types.Type`. - - Returns: - The type of the corresponding unary operator. - """ - return computation_types.FunctionType(type_spec, type_spec) - - -def binary_op(type_spec: computation_types.Type) -> computation_types.Type: - """Returns the type of a binary operator that operates on `type_spec`. - - Args: - type_spec: A `computation_types.Type`. - - Returns: - The type of the corresponding binary operator. - """ - return reduction_op(type_spec, type_spec) diff --git a/tensorflow_federated/python/core/impl/types/type_factory_test.py b/tensorflow_federated/python/core/impl/types/type_factory_test.py deleted file mode 100644 index ca5cf1b864..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_factory_test.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import numpy as np - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_factory - - -class TypeConstructorsTest(absltest.TestCase): - - def test_reduction_op(self): - result_type = computation_types.TensorType(np.float32) - element_type = computation_types.TensorType(np.int32) - actual_type = type_factory.reduction_op(result_type, element_type) - expected_type = computation_types.FunctionType( - computation_types.StructType([result_type, element_type]), result_type - ) - self.assertEqual(actual_type, expected_type) - - def test_unary_op(self): - type_spec = computation_types.TensorType(np.bool_) - actual_type = type_factory.unary_op(type_spec) - expected_type = computation_types.FunctionType(np.bool_, np.bool_) - self.assertEqual(actual_type, expected_type) - - def test_binary_op(self): - type_spec = computation_types.TensorType(np.bool_) - actual_type = type_factory.binary_op(type_spec) - expected_type = computation_types.FunctionType( - computation_types.StructType([type_spec, type_spec]), type_spec - ) - self.assertEqual(actual_type, expected_type) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/type_serialization.py b/tensorflow_federated/python/core/impl/types/type_serialization.py deleted file mode 100644 index 37799d901f..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_serialization.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""A library of (de)serialization functions for computation types.""" - -from collections.abc import Mapping -import weakref - -from tensorflow_federated.proto.v0 import array_pb2 -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import dtype_utils -from tensorflow_federated.python.core.impl.types import placements - - -# Manual cache used rather than `cachetools.cached` due to incompatibility -# with `WeakKeyDictionary`. We want to use a `WeakKeyDictionary` so that -# cache entries are destroyed once the types they index no longer exist. -_type_serialization_cache: Mapping[computation_types.Type, pb.Type] = ( - weakref.WeakKeyDictionary({}) -) - - -def serialize_type(type_spec: computation_types.Type) -> pb.Type: - """Serializes 'type_spec' as a pb.Type. - - Note: Currently only serialization for tensor, named tuple, sequence, and - function types is implemented. - - Args: - type_spec: A `computation_types.Type`. - - Returns: - The corresponding instance of `pb.Type`. - - Raises: - TypeError: if the argument is of the wrong type. - NotImplementedError: for type variants for which serialization is not - implemented. - """ - cached_proto = _type_serialization_cache.get(type_spec) - if cached_proto is not None: - return cached_proto - - if isinstance(type_spec, computation_types.TensorType): - dtype = dtype_utils.to_proto(type_spec.dtype.type) - shape = array_shape.to_proto(type_spec.shape) - proto = pb.Type( - tensor=pb.TensorType( - dtype=dtype, dims=shape.dim, unknown_rank=shape.unknown_rank - ) - ) - elif isinstance(type_spec, computation_types.SequenceType): - proto = pb.Type( - sequence=pb.SequenceType(element=serialize_type(type_spec.element)) - ) - elif isinstance(type_spec, computation_types.StructType): - proto = pb.Type( - struct=pb.StructType( - element=[ - pb.StructType.Element(name=e[0], value=serialize_type(e[1])) - for e in type_spec.items() - ] - ) - ) - elif isinstance(type_spec, computation_types.FunctionType): - if type_spec.parameter is not None: - serialized_parameter = serialize_type(type_spec.parameter) - else: - serialized_parameter = None - proto = pb.Type( - function=pb.FunctionType( - parameter=serialized_parameter, - result=serialize_type(type_spec.result), - ) - ) - elif isinstance(type_spec, computation_types.PlacementType): - proto = pb.Type(placement=pb.PlacementType()) - elif isinstance(type_spec, computation_types.FederatedType): - proto = pb.Type( - federated=pb.FederatedType( - member=serialize_type(type_spec.member), - placement=pb.PlacementSpec( - value=pb.Placement(uri=type_spec.placement.uri) - ), - all_equal=type_spec.all_equal, - ) - ) - else: - raise NotImplementedError - - _type_serialization_cache[type_spec] = proto - return proto - - -def deserialize_type(type_proto: pb.Type) -> computation_types.Type: - """Deserializes 'type_proto' as a `tff.Type`. - - Note: Currently only deserialization for tensor, named tuple, sequence, and - function types is implemented. - - Args: - type_proto: A `pb.Type` to deserialize. - - Returns: - The corresponding instance of `tff.Type`. - - Raises: - TypeError: If the argument is of the wrong type. - NotImplementedError: For type variants for which deserialization is not - implemented. - """ - type_variant = type_proto.WhichOneof('type') - if type_variant == 'tensor': - dtype = dtype_utils.from_proto(type_proto.tensor.dtype) - shape_pb = array_pb2.ArrayShape( - dim=type_proto.tensor.dims, unknown_rank=type_proto.tensor.unknown_rank - ) - shape = array_shape.from_proto(shape_pb) - return computation_types.TensorType(dtype, shape) - elif type_variant == 'sequence': - return computation_types.SequenceType( - deserialize_type(type_proto.sequence.element) - ) - elif type_variant == 'struct': - - def empty_str_to_none(s): - if not s: - return None - return s - - return computation_types.StructType( - [ - (empty_str_to_none(e.name), deserialize_type(e.value)) - for e in type_proto.struct.element - ], - convert=False, - ) - elif type_variant == 'function': - if type_proto.function.HasField('parameter'): - parameter_type = deserialize_type(type_proto.function.parameter) - else: - parameter_type = None - result_type = deserialize_type(type_proto.function.result) - return computation_types.FunctionType( - parameter=parameter_type, result=result_type - ) - elif type_variant == 'placement': - return computation_types.PlacementType() - elif type_variant == 'federated': - placement_oneof = type_proto.federated.placement.WhichOneof('placement') - if placement_oneof == 'value': - return computation_types.FederatedType( - member=deserialize_type(type_proto.federated.member), - placement=placements.uri_to_placement_literal( - type_proto.federated.placement.value.uri - ), - all_equal=type_proto.federated.all_equal, - ) - else: - raise NotImplementedError( - 'Deserialization of federated types with placement spec as {} ' - 'is not currently implemented yet.'.format(placement_oneof) - ) - else: - raise NotImplementedError('Unknown type variant {}.'.format(type_variant)) diff --git a/tensorflow_federated/python/core/impl/types/type_serialization_test.py b/tensorflow_federated/python/core/impl/types/type_serialization_test.py deleted file mode 100644 index 60bc978929..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_serialization_test.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.proto.v0 import data_type_pb2 -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import dtype_utils -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_serialization - - -class TypeSerializationTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('scalar_int', np.int32, []), - ('tensor_int', np.int32, [10, 20]), - ('tensor_undefined_dim_int', np.int32, [None, 10, 20]), - ('scalar_string', np.str_, []), - ('scalar_boo', np.bool_, []), - ) - def test_serialize_tensor_type(self, dtype, shape): - type_signature = computation_types.TensorType(dtype, shape) - actual_proto = type_serialization.serialize_type(type_signature) - dtype = dtype_utils.to_proto(dtype) - shape_pb = array_shape.to_proto(shape) - expected_proto = pb.Type( - tensor=pb.TensorType( - dtype=dtype, - dims=shape_pb.dim, - unknown_rank=shape_pb.unknown_rank, - ) - ) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_type_with_string_sequence(self): - actual_proto = type_serialization.serialize_type( - computation_types.SequenceType(np.str_) - ) - expected_proto = pb.Type( - sequence=pb.SequenceType( - element=pb.Type( - tensor=pb.TensorType(dtype=data_type_pb2.DataType.DT_STRING) - ) - ) - ) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_type_with_tensor_tuple(self): - type_signature = computation_types.StructType([ - ('x', np.int32), - ('y', np.str_), - np.float32, - ('z', np.bool_), - ]) - actual_proto = type_serialization.serialize_type(type_signature) - expected_proto = pb.Type( - struct=pb.StructType( - element=[ - pb.StructType.Element( - name='x', - value=pb.Type( - tensor=pb.TensorType( - dtype=data_type_pb2.DataType.DT_INT32 - ) - ), - ), - pb.StructType.Element( - name='y', - value=pb.Type( - tensor=pb.TensorType( - dtype=data_type_pb2.DataType.DT_STRING - ) - ), - ), - pb.StructType.Element( - value=pb.Type( - tensor=pb.TensorType( - dtype=data_type_pb2.DataType.DT_FLOAT - ) - ) - ), - pb.StructType.Element( - name='z', - value=pb.Type( - tensor=pb.TensorType( - dtype=data_type_pb2.DataType.DT_BOOL - ) - ), - ), - ] - ) - ) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_type_with_nested_tuple(self): - type_signature = computation_types.StructType([ - ('x', [('y', [('z', np.bool_)])]), - ]) - actual_proto = type_serialization.serialize_type(type_signature) - - z_proto = pb.StructType.Element( - name='z', - value=pb.Type( - tensor=pb.TensorType(dtype=data_type_pb2.DataType.DT_BOOL) - ), - ) - expected_proto = pb.Type( - struct=pb.StructType( - element=[ - pb.StructType.Element( - name='x', - value=pb.Type( - struct=pb.StructType( - element=[ - pb.StructType.Element( - name='y', - value=pb.Type( - struct=pb.StructType(element=[z_proto]) - ), - ) - ] - ), - ), - ) - ] - ) - ) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_type_with_function(self): - actual_proto = type_serialization.serialize_type( - computation_types.FunctionType((np.int32, np.int32), np.bool_) - ) - expected_proto = pb.Type( - function=pb.FunctionType( - parameter=pb.Type( - struct=pb.StructType( - element=[ - pb.StructType.Element( - value=pb.Type( - tensor=pb.TensorType( - dtype=data_type_pb2.DataType.DT_INT32 - ) - ) - ), - pb.StructType.Element( - value=pb.Type( - tensor=pb.TensorType( - dtype=data_type_pb2.DataType.DT_INT32 - ) - ) - ), - ] - ) - ), - result=pb.Type( - tensor=pb.TensorType(dtype=data_type_pb2.DataType.DT_BOOL) - ), - ) - ) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_type_with_placement(self): - actual_proto = type_serialization.serialize_type( - computation_types.PlacementType() - ) - expected_proto = pb.Type(placement=pb.PlacementType()) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_type_with_federated_bool(self): - federated_type = computation_types.FederatedType( - np.bool_, placements.CLIENTS, True - ) - actual_proto = type_serialization.serialize_type(federated_type) - expected_proto = pb.Type( - federated=pb.FederatedType( - placement=pb.PlacementSpec( - value=pb.Placement(uri=placements.CLIENTS.uri) - ), - all_equal=True, - member=pb.Type( - tensor=pb.TensorType(dtype=data_type_pb2.DataType.DT_BOOL) - ), - ) - ) - self.assertEqual(actual_proto, expected_proto) - - def test_serialize_deserialize_tensor_types(self): - self._serialize_deserialize_roundtrip_test([ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32, [10]), - computation_types.TensorType(np.int32, [None]), - ]) - - def test_serialize_deserialize_sequence_types(self): - self._serialize_deserialize_roundtrip_test([ - computation_types.SequenceType(np.int32), - computation_types.SequenceType( - computation_types.StructType((np.int32, np.bool_)) - ), - ]) - - def test_serialize_deserialize_named_tuple_types(self): - self._serialize_deserialize_roundtrip_test([ - computation_types.StructType([np.int32, np.bool_]), - computation_types.StructType([ - np.int32, - computation_types.StructType([('x', np.bool_)]), - ]), - computation_types.StructType([('x', np.int32)]), - ]) - - def test_serialize_deserialize_named_tuple_types_py_container(self): - # The Py container is destroyed during ser/de. - with_container = computation_types.StructWithPythonType( - (np.int32, np.bool_), tuple - ) - p1 = type_serialization.serialize_type(with_container) - without_container = type_serialization.deserialize_type(p1) - self.assertNotEqual(with_container, without_container) # Not equal. - self.assertIsInstance(without_container, computation_types.StructType) - self.assertNotIsInstance( - without_container, computation_types.StructWithPythonType - ) - with_container.check_equivalent_to(without_container) - - def test_serialize_deserialize_function_types(self): - self._serialize_deserialize_roundtrip_test([ - computation_types.FunctionType(np.int32, np.bool_), - computation_types.FunctionType(None, np.bool_), - ]) - - def test_serialize_deserialize_placement_type(self): - self._serialize_deserialize_roundtrip_test([ - computation_types.PlacementType(), - ]) - - def test_serialize_deserialize_federated_types(self): - self._serialize_deserialize_roundtrip_test([ - computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=True - ), - computation_types.FederatedType( - np.int32, placements.CLIENTS, all_equal=False - ), - ]) - - def _serialize_deserialize_roundtrip_test(self, type_list): - """Performs roundtrip serialization/deserialization of computation_types. - - Args: - type_list: A list of instances of computation_types.Type or things - convertible to it. - """ - for t1 in type_list: - p1 = type_serialization.serialize_type(t1) - t2 = type_serialization.deserialize_type(p1) - p2 = type_serialization.serialize_type(t2) - self.assertEqual(repr(t1), repr(t2)) - self.assertEqual(repr(p1), repr(p2)) - self.assertTrue(t1.is_equivalent_to(t2)) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/type_test_utils.py b/tensorflow_federated/python/core/impl/types/type_test_utils.py deleted file mode 100644 index d43958f553..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_test_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for testing types.""" - -from tensorflow_federated.python.core.impl.types import computation_types - - -def assert_type_assignable_from(target_type, source_type): - """Asserts that `target_type` is assignable from `source_type`.""" - message = None - try: - target_type.check_assignable_from(source_type) - except computation_types.TypeNotAssignableError as e: - message = e.message - if message is not None: - raise AssertionError(message) - - -def assert_types_equivalent(first_type, second_type): - """Asserts that the types are equivalent.""" - message = None - try: - first_type.check_equivalent_to(second_type) - except computation_types.TypesNotEquivalentError as e: - message = e.message - if message is not None: - raise AssertionError(message) - - -def assert_types_identical(first_type, second_type): - """Asserts that the types are identical.""" - message = None - try: - first_type.check_identical_to(second_type) - except computation_types.TypesNotIdenticalError as e: - message = e.message - if message is not None: - raise AssertionError(message) diff --git a/tensorflow_federated/python/core/impl/types/type_transformations.py b/tensorflow_federated/python/core/impl/types/type_transformations.py deleted file mode 100644 index f0a3f0cd7e..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_transformations.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# limitations under the License. -"""A library of transformation functions for computation types.""" - -from collections.abc import Callable -from typing import TypeVar - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.types import computation_types - -T = TypeVar('T') - - -def strip_placement( - type_signature: computation_types.Type, -) -> computation_types.Type: - """Removes instances of `FederatedType` from `type_signature`.""" - - def _remove_placement(type_signature): - if isinstance(type_signature, computation_types.FederatedType): - return type_signature.member, True - return type_signature, False - - return transform_type_postorder(type_signature, _remove_placement)[0] - - -# TODO: b/134525440 - Unifying the recursive methods in type_analysis. -def transform_type_postorder( - type_signature: computation_types.Type, - transform_fn: Callable[ - [computation_types.Type], tuple[computation_types.Type, bool] - ], -) -> tuple[computation_types.Type, bool]: - """Walks type tree of `type_signature` postorder, calling `transform_fn`. - - Args: - type_signature: Instance of `computation_types.Type` to transform - recursively. - transform_fn: Transformation function to apply to each node in the type tree - of `type_signature`. Must be instance of Python function type. - - Returns: - A possibly transformed version of `type_signature`, with each node in its - tree the result of applying `transform_fn` to the corresponding node in - `type_signature`. - - Raises: - NotImplementedError: If the types don't match the specification above. - """ - py_typecheck.check_type(type_signature, computation_types.Type) - if isinstance(type_signature, computation_types.FederatedType): - transformed_member, member_mutated = transform_type_postorder( - type_signature.member, transform_fn - ) - if member_mutated: - type_signature = computation_types.FederatedType( - transformed_member, type_signature.placement, type_signature.all_equal - ) - type_signature, type_signature_mutated = transform_fn(type_signature) - return type_signature, type_signature_mutated or member_mutated - elif isinstance(type_signature, computation_types.SequenceType): - transformed_element, element_mutated = transform_type_postorder( - type_signature.element, transform_fn - ) - if element_mutated: - type_signature = computation_types.SequenceType(transformed_element) - type_signature, type_signature_mutated = transform_fn(type_signature) - return type_signature, type_signature_mutated or element_mutated - elif isinstance(type_signature, computation_types.FunctionType): - if type_signature.parameter is not None: - transformed_parameter, parameter_mutated = transform_type_postorder( - type_signature.parameter, transform_fn - ) - else: - transformed_parameter, parameter_mutated = (None, False) - transformed_result, result_mutated = transform_type_postorder( - type_signature.result, transform_fn - ) - if parameter_mutated or result_mutated: - type_signature = computation_types.FunctionType( - transformed_parameter, transformed_result - ) - type_signature, type_signature_mutated = transform_fn(type_signature) - return type_signature, ( - type_signature_mutated or parameter_mutated or result_mutated - ) - elif isinstance(type_signature, computation_types.StructType): - elements = [] - elements_mutated = False - for element in type_signature.items(): - transformed_element, element_mutated = transform_type_postorder( - element[1], transform_fn - ) - elements_mutated = elements_mutated or element_mutated - elements.append((element[0], transformed_element)) - if elements_mutated: - if isinstance(type_signature, computation_types.StructWithPythonType): - type_signature = computation_types.StructWithPythonType( - elements, - type_signature.python_container, - ) - else: - type_signature = computation_types.StructType(elements) - type_signature, type_signature_mutated = transform_fn(type_signature) - return type_signature, type_signature_mutated or elements_mutated - elif isinstance( - type_signature, - ( - computation_types.AbstractType, - computation_types.PlacementType, - computation_types.TensorType, - ), - ): - return transform_fn(type_signature) - else: - raise NotImplementedError(f'Unexpected type found: {type_signature}.') - - -# TODO: b/134525440 - Unifying the recursive methods in type_analysis. -def visit_preorder( - type_signature: computation_types.Type, - fn: Callable[[computation_types.Type, T], T], - context: T, -): - """Recursively calls `fn` on the possibly nested structure `type_signature`. - - Walks the tree in a preorder manner. Updates `context` on the way down with - the appropriate information, as defined in `fn`. - - Args: - type_signature: A `computation_types.Type`. - fn: A function to apply to each of the constituent elements of - `type_signature` with the argument `context`. Must return an updated - version of `context` which incorporated the information we'd like to track - as we move down the type tree. - context: Initial state of information to be passed down the tree. - """ - context = fn(type_signature, context) - for child_type in type_signature.children(): - visit_preorder(child_type, fn, context) diff --git a/tensorflow_federated/python/core/impl/types/type_transformations_test.py b/tensorflow_federated/python/core/impl/types/type_transformations_test.py deleted file mode 100644 index 55ed9b4bfb..0000000000 --- a/tensorflow_federated/python/core/impl/types/type_transformations_test.py +++ /dev/null @@ -1,405 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_transformations - - -def _convert_tensor_to_float(type_spec): - if isinstance(type_spec, computation_types.TensorType): - return computation_types.TensorType(np.float32, shape=type_spec.shape), True - return type_spec, False - - -def _convert_abstract_type_to_tensor(type_spec): - if isinstance(type_spec, computation_types.AbstractType): - return computation_types.TensorType(np.float32), True - return type_spec, False - - -def _convert_placement_type_to_tensor(type_spec): - if isinstance(type_spec, computation_types.PlacementType): - return computation_types.TensorType(np.float32), True - return type_spec, False - - -def _convert_function_to_tensor(type_spec): - if isinstance(type_spec, computation_types.FunctionType): - return computation_types.TensorType(np.float32), True - return type_spec, False - - -def _convert_federated_to_tensor(type_spec): - if isinstance(type_spec, computation_types.FederatedType): - return computation_types.TensorType(np.float32), True - return type_spec, False - - -def _convert_sequence_to_tensor(type_spec): - if isinstance(type_spec, computation_types.SequenceType): - return computation_types.TensorType(np.float32), True - return type_spec, False - - -def _convert_tuple_to_tensor(type_spec): - if isinstance(type_spec, computation_types.StructType): - return computation_types.TensorType(np.float32), True - return type_spec, False - - -class StripPlacementTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ( - 'noop_for_non_federated', - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - ), - ( - 'removes_server', - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.TensorType(np.int32), - ), - ( - 'removes_clients', - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.TensorType(np.int32), - ), - ( - 'removes_nested', - computation_types.StructType( - [computation_types.FederatedType(np.int32, placements.SERVER)] - ), - computation_types.StructType([np.int32]), - ), - ( - 'removes_multiple', - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float16, placements.CLIENTS), - ]), - computation_types.StructType([np.int32, np.float16]), - ), - ]) - def test_strips_placement(self, argument, expected): - self.assertEqual(expected, type_transformations.strip_placement(argument)) - - -class TransformTypePostorderTest(absltest.TestCase): - - def test_raises_on_none_type(self): - with self.assertRaises(TypeError): - type_transformations.transform_type_postorder(None, lambda x: x) - - def test_raises_on_none_function(self): - with self.assertRaises(TypeError): - type_transformations.transform_type_postorder( - computation_types.TensorType(np.int32), None - ) - - def test_raises_on_non_type_first_arg(self): - with self.assertRaises(TypeError): - type_transformations.transform_type_postorder(np.int32, None) - - def test_transforms_tensor(self): - orig_type = computation_types.TensorType(np.int32) - expected_type = computation_types.TensorType(np.float32) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_transforms_federated_type(self): - orig_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - expected_type = computation_types.FederatedType( - np.float32, placements.CLIENTS - ) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_recurses_under_federated_type(self): - orig_type = computation_types.FederatedType([np.int32], placements.CLIENTS) - expected_type = computation_types.FederatedType( - [np.float32], placements.CLIENTS - ) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_updates_mutated_bit_at_federated(self): - orig_type = computation_types.FederatedType(np.int32, placements.CLIENTS) - _, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_federated_to_tensor - ) - self.assertTrue(mutated) - - def test_transforms_sequence(self): - orig_type = computation_types.SequenceType(np.int32) - expected_type = computation_types.SequenceType(np.float32) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_recurses_under_sequence(self): - orig_type = computation_types.SequenceType([np.int32]) - expected_type = computation_types.SequenceType([np.float32]) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_updates_mutated_bit_at_sequence(self): - orig_type = computation_types.SequenceType(np.int32) - _, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_sequence_to_tensor - ) - self.assertTrue(mutated) - - def test_transforms_function(self): - orig_type = computation_types.FunctionType(np.int32, np.int64) - expected_type = computation_types.FunctionType(np.float32, np.float32) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_recurses_under_function(self): - orig_type = computation_types.FunctionType([np.int32], np.int64) - expected_type = computation_types.FunctionType([np.float32], np.float32) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_updates_mutated_bit_at_function(self): - orig_type = computation_types.FunctionType(np.int32, np.int64) - _, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_function_to_tensor - ) - self.assertTrue(mutated) - - def test_transforms_unnamed_tuple_type_preserving_tuple_container(self): - orig_type = computation_types.StructWithPythonType( - [np.int32, np.float64], tuple - ) - expected_type = computation_types.StructWithPythonType( - [np.float32, np.float32], tuple - ) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_transforms_unnamed_tuple_type(self): - orig_type = computation_types.StructType([np.int32, np.float64]) - expected_type = computation_types.StructType([np.float32, np.float32]) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_updates_mutated_bit_at_tuple(self): - orig_type = computation_types.StructType([np.int32, np.float64]) - _, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tuple_to_tensor - ) - self.assertTrue(mutated) - - def test_transforms_named_tuple_type(self): - orig_type = computation_types.StructType( - [('a', np.int32), ('b', np.float64)] - ) - expected_type = computation_types.StructType( - [('a', np.float32), ('b', np.float32)] - ) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_recurses_under_named_tuple_type(self): - orig_type = computation_types.StructType( - [[('a', np.int32), ('b', np.float64)]] - ) - expected_type = computation_types.StructType( - [[('a', np.float32), ('b', np.float32)]] - ) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_transforms_named_tuple_type_preserving_tuple_container(self): - orig_type = computation_types.StructWithPythonType( - [('a', np.int32), ('b', np.float64)], dict - ) - expected_type = computation_types.StructWithPythonType( - [('a', np.float32), ('b', np.float32)], dict - ) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_tensor_to_float - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_transforms_abstract_type(self): - orig_type = computation_types.AbstractType('T') - expected_type = computation_types.TensorType(np.float32) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_placement_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - def test_transforms_placement_type(self): - orig_type = computation_types.PlacementType() - expected_type = computation_types.TensorType(np.float32) - result_type, mutated = type_transformations.transform_type_postorder( - orig_type, _convert_placement_type_to_tensor - ) - noop_type, not_mutated = type_transformations.transform_type_postorder( - orig_type, _convert_abstract_type_to_tensor - ) - self.assertEqual(result_type, expected_type) - self.assertEqual(noop_type, orig_type) - self.assertTrue(mutated) - self.assertFalse(not_mutated) - - -class VisitPreorderTest(parameterized.TestCase): - - @parameterized.named_parameters([ - ('abstract_type', computation_types.AbstractType('T'), 1), - ( - 'nested_function_type', - computation_types.FunctionType( - computation_types.FunctionType( - computation_types.FunctionType(np.int32, np.int32), np.int32 - ), - np.int32, - ), - 7, - ), - ( - 'named_tuple_type', - computation_types.StructType([ - np.int32, - np.bool_, - computation_types.SequenceType(np.int32), - ]), - 5, - ), - ('placement_type', computation_types.PlacementType(), 1), - ]) - def test_preorder_call_count(self, type_signature, expected_count): - class Counter: - k = 0 - - def _count_hits(given_type, arg): - del given_type # Unused. - Counter.k += 1 - return arg - - type_transformations.visit_preorder(type_signature, _count_hits, None) - actual_count = Counter.k - self.assertEqual(actual_count, expected_count) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/core/impl/types/typed_object.py b/tensorflow_federated/python/core/impl/types/typed_object.py deleted file mode 100644 index 9833bec37d..0000000000 --- a/tensorflow_federated/python/core/impl/types/typed_object.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2018, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines an abstract interface for things that possess TFF type signatures.""" - -import abc - -from tensorflow_federated.python.core.impl.types import computation_types - - -class TypedObject(metaclass=abc.ABCMeta): - """An abstract interface for things that possess TFF type signatures.""" - - @property - @abc.abstractmethod - def type_signature(self) -> computation_types.Type: - """Returns the TFF type of this object (an instance of `tff.Type`).""" - raise NotImplementedError diff --git a/tensorflow_federated/python/core/templates/BUILD b/tensorflow_federated/python/core/templates/BUILD index 4f00359cd0..52fe510a34 100644 --- a/tensorflow_federated/python/core/templates/BUILD +++ b/tensorflow_federated/python/core/templates/BUILD @@ -35,9 +35,7 @@ py_library( ":errors", ":measured_process", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -49,10 +47,7 @@ py_test( ":aggregation_process", ":errors", ":measured_process", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -68,8 +63,7 @@ py_library( ":errors", ":iterative_process", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", + "@federated_language//federated_language", ], ) @@ -80,10 +74,7 @@ py_test( deps = [ ":errors", ":estimation_process", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -93,9 +84,7 @@ py_library( deps = [ ":errors", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", + "@federated_language//federated_language", ], ) @@ -106,10 +95,7 @@ py_test( deps = [ ":errors", ":iterative_process", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -119,11 +105,7 @@ py_library( deps = [ ":errors", ":iterative_process", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -154,9 +136,6 @@ py_test( ":iterative_process", ":measured_process", "//tensorflow_federated/python/core/impl/compiler:compiler_test_utils", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/core/templates/aggregation_process.py b/tensorflow_federated/python/core/templates/aggregation_process.py index c8d4739889..98e9cf3384 100644 --- a/tensorflow_federated/python/core/templates/aggregation_process.py +++ b/tensorflow_federated/python/core/templates/aggregation_process.py @@ -13,10 +13,8 @@ # limitations under the License. """Defines a template for a stateful process that aggregates values.""" +import federated_language from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process @@ -73,8 +71,8 @@ class AggregationProcess(measured_process.MeasuredProcess): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, ): """Creates a `tff.templates.AggregationProcess`. @@ -112,7 +110,7 @@ def __init__( super().__init__(initialize_fn, next_fn, next_is_multi_arg=True) if not isinstance( - initialize_fn.type_signature.result, computation_types.FederatedType + initialize_fn.type_signature.result, federated_language.FederatedType ): raise AggregationNotFederatedError( 'Provided `initialize_fn` must return a federated type, but found ' @@ -126,7 +124,7 @@ def __init__( non_federated_types = [ t for t in next_types - if not isinstance(t, computation_types.FederatedType) + if not isinstance(t, federated_language.FederatedType) ] if non_federated_types: offending_types_str = '\n- '.join(str(t) for t in non_federated_types) @@ -137,7 +135,10 @@ def __init__( f'The non-federated types are:\n {offending_types_str}.' ) - if initialize_fn.type_signature.result.placement != placements.SERVER: + if ( + initialize_fn.type_signature.result.placement + != federated_language.SERVER + ): raise AggregationPlacementError( 'The state controlled by an `AggregationProcess` must be placed at ' f'the SERVER, but found type: {initialize_fn.type_signature.result}.' @@ -154,25 +155,28 @@ def __init__( f'the following input type: {next_fn_param}.' ) - if next_fn_param[_INPUT_PARAM_INDEX].placement != placements.CLIENTS: + if ( + next_fn_param[_INPUT_PARAM_INDEX].placement + != federated_language.CLIENTS + ): raise AggregationPlacementError( 'The second input argument of `next_fn` must be placed at CLIENTS ' f'but found {next_fn_param[_INPUT_PARAM_INDEX]}.' ) - if next_fn_result.result.placement != placements.SERVER: + if next_fn_result.result.placement != federated_language.SERVER: raise AggregationPlacementError( 'The "result" attribute of return type of `next_fn` must be placed ' f'at SERVER, but found {next_fn_result.result}.' ) - if next_fn_result.measurements.placement != placements.SERVER: + if next_fn_result.measurements.placement != federated_language.SERVER: raise AggregationPlacementError( 'The "measurements" attribute of return type of `next_fn` must be ' f'placed at SERVER, but found {next_fn_result.measurements}.' ) @property - def next(self) -> computation_base.Computation: + def next(self) -> federated_language.framework.Computation: """A `tff.Computation` that runs one iteration of the process. Its first argument should always be the current state (originally produced diff --git a/tensorflow_federated/python/core/templates/aggregation_process_test.py b/tensorflow_federated/python/core/templates/aggregation_process_test.py index f90113d628..6c103cf083 100644 --- a/tensorflow_federated/python/core/templates/aggregation_process_test.py +++ b/tensorflow_federated/python/core/templates/aggregation_process_test.py @@ -15,12 +15,9 @@ import collections from absl.testing import absltest +import federated_language import numpy as np -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process @@ -38,23 +35,23 @@ def _server_zero(): """Returns zero integer placed at SERVER.""" - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _initialize(): return _server_zero() -@federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), +@federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.float32, federated_language.CLIENTS), ) def _next(state, value): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(value), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(value), + federated_language.federated_value(1, federated_language.SERVER), ) @@ -68,19 +65,21 @@ def test_construction_does_not_raise(self): def test_construction_with_empty_state_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_empty(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def _next_empty(state, value): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(value), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(value), + federated_language.federated_value(1, federated_language.SERVER), ) try: @@ -90,24 +89,26 @@ def _next_empty(state, value): def test_construction_with_unknown_dimension_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_unknown(): - return intrinsics.federated_value( - np.array([], np.str_), placements.SERVER + return federated_language.federated_value( + np.array([], np.str_), federated_language.SERVER ) - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.TensorType(np.str_, [None]), - placements.SERVER, + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.TensorType(np.str_, [None]), + federated_language.SERVER, + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS ), - computation_types.FederatedType(np.float32, placements.CLIENTS), ) def next_fn(strings, value): return measured_process.MeasuredProcessOutput( strings, - intrinsics.federated_sum(value), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(value), + federated_language.federated_value(1, federated_language.SERVER), ) try: @@ -120,9 +121,11 @@ def next_fn(strings, value): def test_construction_with_value_type_mismatch_does_not_raise(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def next_fn(state, value): del value # Unused. @@ -155,8 +158,8 @@ def test_next_not_tff_computation_raises(self): def test_init_param_not_empty_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _initialize_arg(x): return x @@ -166,24 +169,26 @@ def _initialize_arg(x): def test_init_state_not_assignable(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_float(): - return intrinsics.federated_value(0.0, placements.SERVER) + return federated_language.federated_value(0.0, federated_language.SERVER) with self.assertRaises(errors.TemplateStateNotAssignableError): aggregation_process.AggregationProcess(_initialize_float, _next) def test_next_state_not_assignable(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def _next_float(state, value): del state # Unused. return measured_process.MeasuredProcessOutput( - intrinsics.federated_value(0.0, placements.SERVER), - intrinsics.federated_sum(value), + federated_language.federated_value(0.0, federated_language.SERVER), + federated_language.federated_sum(value), _server_zero(), ) @@ -191,19 +196,23 @@ def _next_float(state, value): aggregation_process.AggregationProcess(_initialize, _next_float) def test_measured_process_output_as_state_raises(self): - no_value = lambda: intrinsics.federated_value((), placements.SERVER) + no_value = lambda: federated_language.federated_value( + (), federated_language.SERVER + ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def initialize_fn(): - return intrinsics.federated_zip( + return federated_language.federated_zip( measured_process.MeasuredProcessOutput( no_value(), no_value(), no_value() ) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def next_fn(state, value): del state, value @@ -216,12 +225,14 @@ def next_fn(state, value): def test_next_return_tuple_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def tuple_next_fn(state, value): - return state, intrinsics.federated_sum(value), _server_zero() + return state, federated_language.federated_sum(value), _server_zero() with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError): aggregation_process.AggregationProcess(_initialize, tuple_next_fn) @@ -231,13 +242,15 @@ def test_next_return_namedtuple_raises(self): 'MeasuredProcessOutput', ['state', 'result', 'measurements'] ) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def namedtuple_next_fn(state, value): return measured_process_output( - state, intrinsics.federated_sum(value), _server_zero() + state, federated_language.federated_sum(value), _server_zero() ) with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError): @@ -245,14 +258,16 @@ def namedtuple_next_fn(state, value): def test_next_return_odict_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def odict_next_fn(state, value): return collections.OrderedDict( state=state, - result=intrinsics.federated_sum(value), + result=federated_language.federated_sum(value), measurements=_server_zero(), ) @@ -261,14 +276,16 @@ def odict_next_fn(state, value): def test_federated_measured_process_output_raises(self): # Using federated_zip to place FederatedType at the top of the hierarchy. - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def next_fn(state, value): - return intrinsics.federated_zip( + return federated_language.federated_zip( measured_process.MeasuredProcessOutput( - state, intrinsics.federated_sum(value), _server_zero() + state, federated_language.federated_sum(value), _server_zero() ) ) @@ -282,11 +299,11 @@ def next_fn(state, value): def test_non_federated_init_next_raises(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_unplaced(): return 0 - @federated_computation.federated_computation(np.int32, np.float32) + @federated_language.federated_computation(np.int32, np.float32) def _next_unplaced(state, value): return measured_process.MeasuredProcessOutput(state, value, ()) @@ -297,22 +314,28 @@ def _next_unplaced(state, value): def test_init_tuple_of_federated_types_raises(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_tuple(): return (_server_zero(), _server_zero()) - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.SERVER), + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ]), - computation_types.FederatedType(np.float32, placements.CLIENTS), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def _next_tuple(state, value): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(value), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(value), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(aggregation_process.AggregationNotFederatedError): @@ -320,19 +343,21 @@ def _next_tuple(state, value): def test_non_server_placed_init_state_raises(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_clients(): - return intrinsics.federated_value(0, placements.CLIENTS) + return federated_language.federated_value(0, federated_language.CLIENTS) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def _next_non_server(state, value): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(value), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(value), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(aggregation_process.AggregationPlacementError): @@ -342,14 +367,14 @@ def _next_non_server(state, value): def test_single_param_next_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) def _next_single_parameter(state): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_value(1.0, placements.SERVER), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1.0, federated_language.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(errors.TemplateNextFnNumArgsError): @@ -359,15 +384,15 @@ def _next_single_parameter(state): def test_non_clients_placed_next_value_param_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.SERVER), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.float32, federated_language.SERVER), ) def _next_non_clients(state, value): return measured_process.MeasuredProcessOutput( state, value, - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(aggregation_process.AggregationPlacementError): @@ -375,15 +400,15 @@ def _next_non_clients(state, value): def test_non_server_placed_next_result_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.SERVER), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.float32, federated_language.SERVER), ) def _next_non_server_result(state, value): return measured_process.MeasuredProcessOutput( state, value, - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(aggregation_process.AggregationPlacementError): @@ -393,13 +418,13 @@ def _next_non_server_result(state, value): def test_non_server_placed_next_measurements_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) def next_fn(state, value): return measured_process.MeasuredProcessOutput( - state, intrinsics.federated_sum(value), value + state, federated_language.federated_sum(value), value ) with self.assertRaises(aggregation_process.AggregationPlacementError): @@ -409,17 +434,21 @@ def test_is_weighted_property(self): process = aggregation_process.AggregationProcess(_initialize, _next) self.assertFalse(process.is_weighted) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def weighted_next_fn(state, value, weight): del weight return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(value), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(value), + federated_language.federated_value(1, federated_language.SERVER), ) process = aggregation_process.AggregationProcess( diff --git a/tensorflow_federated/python/core/templates/estimation_process.py b/tensorflow_federated/python/core/templates/estimation_process.py index 4cba29a283..7843ead4f4 100644 --- a/tensorflow_federated/python/core/templates/estimation_process.py +++ b/tensorflow_federated/python/core/templates/estimation_process.py @@ -15,9 +15,9 @@ from typing import Optional +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import iterative_process @@ -43,9 +43,9 @@ class EstimationProcess(iterative_process.IterativeProcess): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, - report_fn: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, + report_fn: federated_language.framework.Computation, next_is_multi_arg: Optional[bool] = None, ): """Creates a `tff.templates.EstimationProcess`. @@ -76,7 +76,7 @@ def __init__( initialize_fn, next_fn, next_is_multi_arg=next_is_multi_arg ) - py_typecheck.check_type(report_fn, computation_base.Computation) + py_typecheck.check_type(report_fn, federated_language.framework.Computation) report_fn_arg_type = report_fn.type_signature.parameter if report_fn_arg_type is None or not report_fn_arg_type.is_assignable_from( self.state_type @@ -92,7 +92,7 @@ def __init__( self._report_fn = report_fn @property - def report(self) -> computation_base.Computation: + def report(self) -> federated_language.framework.Computation: """A `tff.Computation` that computes the current estimate from `state`. Given a `state` controlled by this process, computes and returns the most @@ -103,7 +103,7 @@ def report(self) -> computation_base.Computation: """ return self._report_fn - def map(self, map_fn: computation_base.Computation): + def map(self, map_fn: federated_language.framework.Computation): """Applies `map_fn` to the estimate function of the process. This method will return a new instance of `EstimationProcess` with the same @@ -121,7 +121,7 @@ def map(self, map_fn: computation_base.Computation): EstimateNotAssignableError: If the return type of `report` is not assignable to the expected input type of `map_fn`. """ - py_typecheck.check_type(map_fn, computation_base.Computation) + py_typecheck.check_type(map_fn, federated_language.framework.Computation) estimate_type = self.report.type_signature.result map_fn_arg_type = map_fn.type_signature.parameter @@ -136,7 +136,7 @@ def map(self, map_fn: computation_base.Computation): f'and the argument of `map_fn` is:\n{map_fn_arg_type}' ) - transformed_report_fn = federated_computation.federated_computation( + transformed_report_fn = federated_language.federated_computation( lambda state: map_fn(self.report(state)), self.state_type ) diff --git a/tensorflow_federated/python/core/templates/estimation_process_test.py b/tensorflow_federated/python/core/templates/estimation_process_test.py index 098434013e..d68b09436f 100644 --- a/tensorflow_federated/python/core/templates/estimation_process_test.py +++ b/tensorflow_federated/python/core/templates/estimation_process_test.py @@ -13,12 +13,8 @@ # limitations under the License. from absl.testing import absltest +import federated_language import numpy as np - -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import estimation_process @@ -29,28 +25,28 @@ ) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _initialize(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) -@federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) +@federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _next(state): return state -@federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) +@federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _report(state): del state # Unused. - return intrinsics.federated_value(1.0, placements.SERVER) + return federated_language.federated_value(1.0, federated_language.SERVER) -@federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.SERVER) +@federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.SERVER) ) def _map(state): return (state, state) @@ -66,18 +62,18 @@ def test_construction_does_not_raise(self): def test_construction_with_empty_state_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_empty(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER) ) def _next_empty(state): return (state, 1.0) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER) ) def _report_empty(state): return state @@ -91,23 +87,25 @@ def _report_empty(state): def test_construction_with_unknown_dimension_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_unknown(): - return intrinsics.federated_value( - np.array([], np.str_), placements.SERVER + return federated_language.federated_value( + np.array([], np.str_), federated_language.SERVER ) - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.TensorType(np.str_, [None]), placements.SERVER + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.TensorType(np.str_, [None]), + federated_language.SERVER, ) ) def _next_unknown(state): return state - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.TensorType(np.str_, [None]), placements.SERVER + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.TensorType(np.str_, [None]), + federated_language.SERVER, ) ) def _report_unknown(state): @@ -149,8 +147,8 @@ def _report_py(x): def test_init_param_not_empty_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _initialize_arg(x): return x @@ -160,17 +158,17 @@ def _initialize_arg(x): def test_init_state_not_assignable(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_float(): - return intrinsics.federated_value(0.0, placements.SERVER) + return federated_language.federated_value(0.0, federated_language.SERVER) with self.assertRaises(errors.TemplateStateNotAssignableError): estimation_process.EstimationProcess(_initialize_float, _next, _report) def test_next_state_not_assignable(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.SERVER) ) def _next_float(state): return state @@ -179,10 +177,15 @@ def _next_float(state): estimation_process.EstimationProcess(_initialize, _next_float, _report) def test_next_state_not_assignable_tuple_result(self): - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.float32, placements.SERVER), - computation_types.FederatedType(np.float32, placements.SERVER), + + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ]), ) def _next_float(state, value): @@ -195,7 +198,7 @@ def _next_float(state, value): def test_report_state_not_assignable(self): - @federated_computation.federated_computation(np.float32) + @federated_language.federated_computation(np.float32) def _report_float(state): return state @@ -220,7 +223,7 @@ def test_mapped_process_as_expected(self): def test_map_estimate_not_assignable(self): - @federated_computation.federated_computation(np.int32) + @federated_language.federated_computation(np.int32) def _map_int(x): return x diff --git a/tensorflow_federated/python/core/templates/iterative_process.py b/tensorflow_federated/python/core/templates/iterative_process.py index 8b3cd32ed9..6efe84e4cc 100644 --- a/tensorflow_federated/python/core/templates/iterative_process.py +++ b/tensorflow_federated/python/core/templates/iterative_process.py @@ -15,16 +15,15 @@ from typing import Optional +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import errors def _is_nonempty_struct(type_signature) -> bool: return ( - isinstance(type_signature, computation_types.StructType) + isinstance(type_signature, federated_language.StructType) and type_signature ) @@ -120,8 +119,8 @@ def next_fn(state, round_num): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, next_is_multi_arg: Optional[bool] = None, ): """Creates a `tff.templates.IterativeProcess`. @@ -148,7 +147,9 @@ def __init__( `initialize_fn` or `next_fn` is not assignable to the first input argument of `next_fn`. """ - py_typecheck.check_type(initialize_fn, computation_base.Computation) + py_typecheck.check_type( + initialize_fn, federated_language.framework.Computation + ) if initialize_fn.type_signature.parameter is not None: raise errors.TemplateInitFnParamNotEmptyError( 'Provided `initialize_fn` must be a no-arg function, but found ' @@ -156,7 +157,7 @@ def __init__( ) initialize_result_type = initialize_fn.type_signature.result - py_typecheck.check_type(next_fn, computation_base.Computation) + py_typecheck.check_type(next_fn, federated_language.framework.Computation) next_parameter_type = next_fn.type_signature.parameter state_type = _infer_state_type( initialize_result_type, next_parameter_type, next_is_multi_arg @@ -186,12 +187,12 @@ def __init__( self._next_fn = next_fn @property - def initialize(self) -> computation_base.Computation: + def initialize(self) -> federated_language.framework.Computation: """A no-arg `tff.Computation` that returns the initial state.""" return self._initialize_fn @property - def next(self) -> computation_base.Computation: + def next(self) -> federated_language.framework.Computation: """A `tff.Computation` that produces the next state. Its first argument should always be the current state (originally produced @@ -204,7 +205,7 @@ def next(self) -> computation_base.Computation: return self._next_fn @property - def state_type(self) -> computation_types.Type: + def state_type(self) -> federated_language.Type: """The `tff.Type` of the state of the process.""" return self._state_type # pytype: disable=bad-return-type @@ -227,8 +228,8 @@ def is_stateful(process: IterativeProcess) -> bool: contains types other than `tff.types.StructType`, `False` otherwise. """ state_type = process.state_type - if isinstance(state_type, computation_types.FederatedType): + if isinstance(state_type, federated_language.FederatedType): state_type = state_type.member - return not type_analysis.contains_only( - state_type, lambda t: isinstance(t, computation_types.StructType) + return not federated_language.framework.type_contains_only( + state_type, lambda t: isinstance(t, federated_language.StructType) ) diff --git a/tensorflow_federated/python/core/templates/iterative_process_test.py b/tensorflow_federated/python/core/templates/iterative_process_test.py index 6b31f9aed9..555fc2ed55 100644 --- a/tensorflow_federated/python/core/templates/iterative_process_test.py +++ b/tensorflow_federated/python/core/templates/iterative_process_test.py @@ -14,12 +14,8 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np - -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import iterative_process @@ -31,13 +27,13 @@ ) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _initialize(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) -@federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) +@federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _next(state): return state @@ -53,12 +49,12 @@ def test_construction_does_not_raise(self): def test_construction_with_empty_state_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_empty(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER) ) def _next_empty(state): return state @@ -70,15 +66,16 @@ def _next_empty(state): def test_construction_with_unknown_dimension_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_unknown(): - return intrinsics.federated_value( - np.array([], np.str_), placements.SERVER + return federated_language.federated_value( + np.array([], np.str_), federated_language.SERVER ) - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.TensorType(np.str_, [None]), placements.SERVER + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.TensorType(np.str_, [None]), + federated_language.SERVER, ) ) def _next_unknown(state): @@ -104,8 +101,8 @@ def test_next_not_tff_computation_raises(self): def test_init_param_not_empty_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _initialize_arg(x): return x @@ -114,17 +111,18 @@ def _initialize_arg(x): iterative_process.IterativeProcess(_initialize_arg, _next) def test_init_state_not_assignable(self): - @federated_computation.federated_computation() + + @federated_language.federated_computation() def _initialize_float(): - return intrinsics.federated_value(0.0, placements.SERVER) + return federated_language.federated_value(0.0, federated_language.SERVER) with self.assertRaises(errors.TemplateStateNotAssignableError): iterative_process.IterativeProcess(_initialize_float, _next) def test_next_state_not_assignable(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.float32, federated_language.SERVER) ) def _next_float(state): return state @@ -133,10 +131,15 @@ def _next_float(state): iterative_process.IterativeProcess(_initialize, _next_float) def test_next_state_not_assignable_tuple_result(self): - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.float32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ]), ) def _next_float(state, value): @@ -147,16 +150,16 @@ def _next_float(state, value): def _create_test_process( - state_type: computation_types.Type, state: object + state_type: federated_language.Type, state: object ) -> iterative_process.IterativeProcess: - @federated_computation.federated_computation + @federated_language.federated_computation def _init_process(): - if isinstance(state_type, computation_types.FederatedType): - return intrinsics.federated_value(state, state_type.placement) + if isinstance(state_type, federated_language.FederatedType): + return federated_language.federated_value(state, state_type.placement) else: return state - @federated_computation.federated_computation(state_type, np.int32) + @federated_language.federated_computation(state_type, np.int32) def _next_process(state, value): return state, value @@ -166,25 +169,27 @@ def _next_process(state, value): class HasEmptyStateTest(parameterized.TestCase, absltest.TestCase): @parameterized.named_parameters( - ('struct_tuple_empty', computation_types.StructType([]), ()), + ('struct_tuple_empty', federated_language.StructType([]), ()), ( 'struct_list_empty', - computation_types.StructWithPythonType([], list), + federated_language.StructWithPythonType([], list), [], ), ( 'struct_nested_empty', - computation_types.StructType([[], [[]]]), + federated_language.StructType([[], [[]]]), ((), ((),)), ), ( 'federated_struct_empty', - computation_types.FederatedType([], placements.SERVER), + federated_language.FederatedType([], federated_language.SERVER), (), ), ( 'federated_struct_nested_empty', - computation_types.FederatedType([[], [[]]], placements.SERVER), + federated_language.FederatedType( + [[], [[]]], federated_language.SERVER + ), ((), ((),)), ), ) @@ -193,31 +198,31 @@ def test_is_stateful_returns_false(self, state_type, state): self.assertFalse(iterative_process.is_stateful(process)) @parameterized.named_parameters( - ('tensor', computation_types.TensorType(np.int32), 1), + ('tensor', federated_language.TensorType(np.int32), 1), ( 'struct_tuple_tensor', - computation_types.StructType([np.int32]), + federated_language.StructType([np.int32]), (1,), ), ( 'struct_list_tensor', - computation_types.StructWithPythonType([np.int32], list), + federated_language.StructWithPythonType([np.int32], list), [1], ), ( 'struct_nested_tensor', - computation_types.StructType([[], [[np.int32]]]), + federated_language.StructType([[], [[np.int32]]]), ((), ((1,),)), ), ( 'federated_tensor', - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), 1, ), ( 'federated_struct_nested_tensor', - computation_types.FederatedType( - [[], [[np.int32]]], placements.SERVER + federated_language.FederatedType( + [[], [[np.int32]]], federated_language.SERVER ), ((), ((1,),)), ), diff --git a/tensorflow_federated/python/core/templates/measured_process.py b/tensorflow_federated/python/core/templates/measured_process.py index b6713f8c20..5a81ac8dab 100644 --- a/tensorflow_federated/python/core/templates/measured_process.py +++ b/tensorflow_federated/python/core/templates/measured_process.py @@ -16,13 +16,9 @@ import collections from typing import NamedTuple, Optional +import federated_language import tree -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import iterative_process @@ -65,8 +61,8 @@ class MeasuredProcess(iterative_process.IterativeProcess): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, next_is_multi_arg: Optional[bool] = None, ): """Creates a `tff.templates.MeasuredProcess`. @@ -99,7 +95,7 @@ def __init__( super().__init__(initialize_fn, next_fn, next_is_multi_arg) next_result_type = next_fn.type_signature.result if not ( - isinstance(next_result_type, computation_types.StructWithPythonType) + isinstance(next_result_type, federated_language.StructWithPythonType) and next_result_type.python_container is MeasuredProcessOutput ): raise errors.TemplateNotMeasuredProcessOutputError( @@ -166,10 +162,10 @@ def chain_measured_processes( """ # Concatenate all the initialization computations. - @federated_computation.federated_computation + @federated_language.federated_computation def composition_initialize(): try: - return intrinsics.federated_zip( + return federated_language.federated_zip( collections.OrderedDict( (name, process.initialize()) for name, process in measured_processes.items() @@ -187,15 +183,15 @@ def composition_initialize(): first_process = next(iter(measured_processes.values())) first_process_value_type_spec = first_process.next.type_signature.parameter[1] # pytype: disable=unsupported-operands - concatenated_state_type_spec = computation_types.FederatedType( - computation_types.StructType([ + concatenated_state_type_spec = federated_language.FederatedType( + federated_language.StructType([ (name, process.next.type_signature.parameter[0].member) # pytype: disable=unsupported-operands for name, process in measured_processes.items() ]), - placements.SERVER, + federated_language.SERVER, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( concatenated_state_type_spec, first_process_value_type_spec ) def composition_next(state, values): @@ -217,9 +213,9 @@ def composition_next(state, values): measurements[name] = output.measurements values = output.result return MeasuredProcessOutput( - state=intrinsics.federated_zip(new_states), + state=federated_language.federated_zip(new_states), result=values, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return MeasuredProcess(composition_initialize, composition_next) @@ -265,10 +261,10 @@ def concatenate_measured_processes( """ # Concatenate all the initialization computations. - @federated_computation.federated_computation + @federated_language.federated_computation def concatenation_initialize(): try: - return intrinsics.federated_zip( + return federated_language.federated_zip( collections.OrderedDict( (name, process.initialize()) for name, process in measured_processes.items() @@ -284,11 +280,11 @@ def concatenation_initialize(): f'placement of the state: {state_type}.' ) from e - concatenated_state_type_spec = computation_types.FederatedType( + concatenated_state_type_spec = federated_language.FederatedType( tree.map_structure( lambda process: process.state_type.member, measured_processes ), - placements.SERVER, + federated_language.SERVER, ) concatenated_values_type_spec = tree.map_structure( lambda process: process.next.type_signature.parameter[1], @@ -296,7 +292,7 @@ def concatenation_initialize(): ) # Concatenate all the next computations. - @federated_computation.federated_computation( + @federated_language.federated_computation( concatenated_state_type_spec, concatenated_values_type_spec ) def concatenation_next(state, values): @@ -309,9 +305,9 @@ def concatenation_next(state, values): results[name] = output.result measurements[name] = output.measurements return MeasuredProcessOutput( - state=intrinsics.federated_zip(new_states), + state=federated_language.federated_zip(new_states), result=results, - measurements=intrinsics.federated_zip(measurements), + measurements=federated_language.federated_zip(measurements), ) return MeasuredProcess(concatenation_initialize, concatenation_next) diff --git a/tensorflow_federated/python/core/templates/measured_process_test.py b/tensorflow_federated/python/core/templates/measured_process_test.py index fb8f78e20b..228594e063 100644 --- a/tensorflow_federated/python/core/templates/measured_process_test.py +++ b/tensorflow_federated/python/core/templates/measured_process_test.py @@ -16,13 +16,10 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.core.impl.compiler import compiler_test_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import iterative_process @@ -36,15 +33,15 @@ ) -@federated_computation.federated_computation() +@federated_language.federated_computation() def _initialize(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) -@federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), +@federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) ) def _next(state, value): @@ -61,14 +58,16 @@ def test_construction_does_not_raise(self): def test_construction_with_empty_state_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_empty(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ]) ) def _next_empty(state, value): @@ -81,18 +80,21 @@ def _next_empty(state, value): def test_construction_with_unknown_dimension_does_not_raise(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_unknown(): - return intrinsics.federated_value( - np.array([], np.str_), placements.SERVER + return federated_language.federated_value( + np.array([], np.str_), federated_language.SERVER ) - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType( - computation_types.TensorType(np.str_, [None]), placements.SERVER + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + federated_language.TensorType(np.str_, [None]), + federated_language.SERVER, + ), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS ), - computation_types.FederatedType(np.int32, placements.CLIENTS), ]) ) def _next_unknown(state, value): @@ -124,8 +126,8 @@ def _next_py(state, value): def test_init_param_not_empty_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def _initialize_arg(x): return x @@ -135,19 +137,23 @@ def _initialize_arg(x): def test_init_state_not_assignable(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_float(): - return intrinsics.federated_value(0.0, placements.SERVER) + return federated_language.federated_value(0.0, federated_language.SERVER) with self.assertRaises(errors.TemplateStateNotAssignableError): measured_process.MeasuredProcess(_initialize_float, _next) def test_next_state_not_assignable(self): - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.float32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ]) ) def _next_float(state, value): @@ -160,17 +166,19 @@ def _next_float(state, value): def test_measured_process_output_as_state_raises(self): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_process(): value = measured_process.MeasuredProcessOutput((), (), ()) - return intrinsics.federated_value(value, placements.SERVER) + return federated_language.federated_value( + value, federated_language.SERVER + ) - @federated_computation.federated_computation( - computation_types.StructWithPythonType( + @federated_language.federated_computation( + federated_language.StructWithPythonType( elements=[[ - computation_types.StructType([]), - computation_types.StructType([]), - computation_types.StructType([]), + federated_language.StructType([]), + federated_language.StructType([]), + federated_language.StructType([]), ]], container_type=measured_process.MeasuredProcessOutput, ) @@ -183,8 +191,8 @@ def _next_process(state): def test_next_return_tensor_type_raises(self): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), ) def _next_tensor(state): return state @@ -194,10 +202,14 @@ def _next_tensor(state): def test_next_return_tuple_raises(self): - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ]) ) def _next_tuple(state, value): @@ -211,10 +223,14 @@ def test_next_return_namedtuple_raises(self): 'MeasuredProcessOutput', ['state', 'result', 'measurements'] ) - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ]) ) def _next_named_tuple(state, value): @@ -225,10 +241,14 @@ def _next_named_tuple(state, value): def test_next_return_odict_raises(self): - @federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + np.int32, federated_language.SERVER + ), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ]) ) def _next_odict(state, value): @@ -242,16 +262,18 @@ def _next_odict(state, value): measured_process.MeasuredProcess(_initialize, _next_odict) def test_federated_measured_process_output_raises(self): - initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value(0, placements.SERVER) + initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value(0, federated_language.SERVER) + ) + empty = lambda: federated_language.federated_value( + (), federated_language.SERVER ) - empty = lambda: intrinsics.federated_value((), placements.SERVER) state_type = initialize_fn.type_signature.result # Using federated_zip to place FederatedType at the top of the hierarchy. - @federated_computation.federated_computation(state_type) + @federated_language.federated_computation(state_type) def next_fn(state): - return intrinsics.federated_zip( + return federated_language.federated_zip( measured_process.MeasuredProcessOutput(state, empty(), empty()) ) @@ -264,24 +286,26 @@ def next_fn(state): def _create_test_measured_process_double(state_type, state_init, values_type): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_double(): - return intrinsics.federated_value(state_init, placements.SERVER) + return federated_language.federated_value( + state_init, federated_language.SERVER + ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def _double(x): return x - @federated_computation.federated_computation( - computation_types.FederatedType(state_type, placements.SERVER), - computation_types.FederatedType(values_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(state_type, federated_language.SERVER), + federated_language.FederatedType(values_type, federated_language.CLIENTS), ) def _next_double(state, values): return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_map(_double, state), - result=intrinsics.federated_map(_double, values), - measurements=intrinsics.federated_value( - collections.OrderedDict(a=1), placements.SERVER + state=federated_language.federated_map(_double, state), + result=federated_language.federated_map(_double, values), + measurements=federated_language.federated_value( + collections.OrderedDict(a=1), federated_language.SERVER ), ) @@ -290,24 +314,26 @@ def _next_double(state, values): def _create_test_measured_process_sum(state_type, state_init, values_type): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_sum(): - return intrinsics.federated_value(state_init, placements.SERVER) + return federated_language.federated_value( + state_init, federated_language.SERVER + ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def _sum(x): return x - @federated_computation.federated_computation( - computation_types.FederatedType(state_type, placements.SERVER), - computation_types.FederatedType(values_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(state_type, federated_language.SERVER), + federated_language.FederatedType(values_type, federated_language.CLIENTS), ) def _next_sum(state, values): return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_map(_sum, state), - result=intrinsics.federated_sum(values), - measurements=intrinsics.federated_value( - collections.OrderedDict(b=2), placements.SERVER + state=federated_language.federated_map(_sum, state), + result=federated_language.federated_sum(values), + measurements=federated_language.federated_value( + collections.OrderedDict(b=2), federated_language.SERVER ), ) @@ -316,20 +342,22 @@ def _next_sum(state, values): def _create_test_measured_process_state_at_clients(): - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType(np.int32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ) def next_fn(state, values): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(values), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(values), + federated_language.federated_value(1, federated_language.SERVER), ) return measured_process.MeasuredProcess( - initialize_fn=federated_computation.federated_computation( - lambda: intrinsics.federated_value(0, placements.CLIENTS) + initialize_fn=federated_language.federated_computation( + lambda: federated_language.federated_value( + 0, federated_language.CLIENTS + ) ), next_fn=next_fn, ) @@ -337,12 +365,12 @@ def next_fn(state, values): def _create_test_measured_process_state_missing_placement(): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_unplaced(): return 0 - @federated_computation.federated_computation( - computation_types.StructType([np.int32, np.int32]) + @federated_language.federated_computation( + federated_language.StructType([np.int32, np.int32]) ) def _next_unplaced(state, value): return measured_process.MeasuredProcessOutput(state, value, ()) @@ -352,20 +380,22 @@ def _next_unplaced(state, value): def _create_test_aggregation_process(state_type, state_init, values_type): - @federated_computation.federated_computation( - computation_types.FederatedType(state_type, placements.SERVER), - computation_types.FederatedType(values_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType(state_type, federated_language.SERVER), + federated_language.FederatedType(values_type, federated_language.CLIENTS), ) def next_fn(state, values): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_sum(values), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_sum(values), + federated_language.federated_value(1, federated_language.SERVER), ) return aggregation_process.AggregationProcess( - initialize_fn=federated_computation.federated_computation( - lambda: intrinsics.federated_value(state_init, placements.SERVER) + initialize_fn=federated_language.federated_computation( + lambda: federated_language.federated_value( + state_init, federated_language.SERVER + ) ), next_fn=next_fn, ) @@ -373,11 +403,11 @@ def next_fn(state, values): def _create_test_iterative_process(state_type, state_init): - @federated_computation.federated_computation() + @federated_language.federated_computation() def _initialize_ip(): return state_init - @federated_computation.federated_computation(state_type) + @federated_language.federated_computation(state_type) def _next_ip(state): return state @@ -404,11 +434,11 @@ def test_composition_type_properties(self, last_process): ) self.assertIsInstance(composite_process, measured_process.MeasuredProcess) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( collections.OrderedDict(double=state_type, last_process=state_type), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -417,20 +447,20 @@ def test_composition_type_properties(self, last_process): ) ) - param_value_type = computation_types.FederatedType( - values_type, placements.CLIENTS + param_value_type = federated_language.FederatedType( + values_type, federated_language.CLIENTS ) - result_value_type = computation_types.FederatedType( - values_type, placements.SERVER + result_value_type = federated_language.FederatedType( + values_type, federated_language.SERVER ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( double=collections.OrderedDict(a=np.int32), last_process=last_process.next.type_signature.result.measurements.member, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, values=param_value_type ), @@ -564,11 +594,11 @@ def test_concatenation_type_properties(self, last_process): concatenated_process, measured_process.MeasuredProcess ) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( collections.OrderedDict(double=state_type, last_process=state_type), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -578,25 +608,29 @@ def test_concatenation_type_properties(self, last_process): ) param_value_type = collections.OrderedDict( - double=computation_types.FederatedType(values_type, placements.CLIENTS), - last_process=computation_types.FederatedType( - values_type, placements.CLIENTS + double=federated_language.FederatedType( + values_type, federated_language.CLIENTS + ), + last_process=federated_language.FederatedType( + values_type, federated_language.CLIENTS ), ) result_value_type = collections.OrderedDict( - double=computation_types.FederatedType(values_type, placements.CLIENTS), - last_process=computation_types.FederatedType( - values_type, placements.SERVER + double=federated_language.FederatedType( + values_type, federated_language.CLIENTS + ), + last_process=federated_language.FederatedType( + values_type, federated_language.SERVER ), ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( double=collections.OrderedDict(a=np.int32), last_process=last_process.next.type_signature.result.measurements.member, ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, values=param_value_type ), diff --git a/tensorflow_federated/python/core/test/BUILD b/tensorflow_federated/python/core/test/BUILD index 259c6cad36..fc0b86276e 100644 --- a/tensorflow_federated/python/core/test/BUILD +++ b/tensorflow_federated/python/core/test/BUILD @@ -1,4 +1,4 @@ -load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@rules_python//python:defs.bzl", "py_library") package( default_applicable_licenses = ["//:package_license"], @@ -19,34 +19,5 @@ py_library( name = "test", srcs = ["__init__.py"], visibility = ["//tensorflow_federated:__pkg__"], - deps = [ - ":static_assert", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_test_utils", - "//tensorflow_federated/python/core/impl/context_stack:runtime_error_context", - "//tensorflow_federated/python/core/impl/context_stack:set_default_context", - "//tensorflow_federated/python/core/impl/types:type_test_utils", - ], -) - -py_library( - name = "static_assert", - srcs = ["static_assert.py"], - deps = [ - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/compiler:building_blocks", - "//tensorflow_federated/python/core/impl/compiler:tree_analysis", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - ], -) - -py_test( - name = "static_assert_test", - size = "small", - srcs = ["static_assert_test.py"], - deps = [ - ":static_assert", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:placements", - ], + deps = ["@federated_language//federated_language"], ) diff --git a/tensorflow_federated/python/core/test/__init__.py b/tensorflow_federated/python/core/test/__init__.py index 12fea807d3..92ee66e710 100644 --- a/tensorflow_federated/python/core/test/__init__.py +++ b/tensorflow_federated/python/core/test/__init__.py @@ -12,15 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. """Libraries for testing TensorFlow Federated.""" +import federated_language -from tensorflow_federated.python.core.impl.context_stack.context_stack_test_utils import with_context -from tensorflow_federated.python.core.impl.context_stack.context_stack_test_utils import with_contexts -from tensorflow_federated.python.core.impl.context_stack.runtime_error_context import create_runtime_error_context -from tensorflow_federated.python.core.impl.context_stack.set_default_context import set_no_default_context -from tensorflow_federated.python.core.impl.types.type_test_utils import assert_type_assignable_from -from tensorflow_federated.python.core.impl.types.type_test_utils import assert_types_equivalent -from tensorflow_federated.python.core.impl.types.type_test_utils import assert_types_identical -from tensorflow_federated.python.core.test.static_assert import assert_contains_secure_aggregation -from tensorflow_federated.python.core.test.static_assert import assert_contains_unsecure_aggregation -from tensorflow_federated.python.core.test.static_assert import assert_not_contains_secure_aggregation -from tensorflow_federated.python.core.test.static_assert import assert_not_contains_unsecure_aggregation +with_context = federated_language.framework.with_context +with_contexts = federated_language.framework.with_contexts +create_runtime_error_context = ( + federated_language.framework.create_runtime_error_context +) +set_no_default_context = federated_language.framework.set_no_default_context +assert_type_assignable_from = ( + federated_language.framework.assert_type_assignable_from +) +assert_types_equivalent = federated_language.framework.assert_types_equivalent +assert_types_identical = federated_language.framework.assert_types_identical +assert_contains_secure_aggregation = ( + federated_language.framework.assert_contains_secure_aggregation +) +assert_contains_unsecure_aggregation = ( + federated_language.framework.assert_contains_unsecure_aggregation +) +assert_not_contains_secure_aggregation = ( + federated_language.framework.assert_not_contains_secure_aggregation +) +assert_not_contains_unsecure_aggregation = ( + federated_language.framework.assert_not_contains_unsecure_aggregation +) diff --git a/tensorflow_federated/python/core/test/static_assert.py b/tensorflow_federated/python/core/test/static_assert.py deleted file mode 100644 index 60ea39ec56..0000000000 --- a/tensorflow_federated/python/core/test/static_assert.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Classes/functions for statically asserting properties of TFF computations.""" - -from typing import Optional - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.compiler import building_blocks -from tensorflow_federated.python.core.impl.compiler import tree_analysis -from tensorflow_federated.python.core.impl.computation import computation_impl - - -def _raise_expected_none( - calls: list[building_blocks.Call], kind: str -) -> Optional[str]: - if not calls: - raise AssertionError('Expected `calls` to not be empty.') - msg = 'Expected no {} aggregations, found {}:'.format(kind, len(calls)) - msg += ''.join(('\n\t' + call.compact_representation() for call in calls)) - raise AssertionError(msg) - - -def assert_contains_secure_aggregation(comp): - """Asserts that `comp` contains at least one secure aggregation call. - - Args: - comp: A `tff.Computation`, often a function annotated with - `tff.federated_computation` or `tff.tensorflow.computation`. Note that - polymorphic functions (those without the types of their arguments - explicitly specified) will not yet be `tff.Computation`s. - - Raises: - AssertionError if `comp` does not contain a secure aggregation call. - ValueError if `comp` contains a call whose target function cannot be - identified. This may result from calls to references or other - indirect structures. - """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) - comp = comp.to_building_block() - calls = tree_analysis.find_secure_aggregation_in_tree(comp) - if not calls: - raise AssertionError( - 'Expected secure aggregation, but none were found in: {}'.format( - comp.compact_representation() - ) - ) - - -def assert_not_contains_secure_aggregation(comp): - """Asserts that `comp` contains no secure aggregation calls. - - Args: - comp: A `tff.Computation`, often a function annotated with - `tff.federated_computation` or `tff.tensorflow.computation`. Note that - polymorphic functions (those without the types of their arguments - explicitly specified) will not yet be `tff.Computation`s. - - Raises: - AssertionError if `comp` contains a secure aggregation call. - ValueError if `comp` contains a call whose target function cannot be - identified. This may result from calls to references or other - indirect structures. - """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) - comp = comp.to_building_block() - calls = tree_analysis.find_secure_aggregation_in_tree(comp) - if calls: - _raise_expected_none(calls, 'secure') - - -def assert_contains_unsecure_aggregation(comp): - """Asserts that `comp` contains at least one unsecure aggregation call. - - Args: - comp: A `tff.Computation`, often a function annotated with - `tff.federated_computation` or `tff.tensorflow.computation`. Note that - polymorphic functions (those without the types of their arguments - explicitly specified) will not yet be `tff.Computation`s. - - Raises: - AssertionError if `comp` does not contain an unsecure aggregation call. - ValueError if `comp` contains a call whose target function cannot be - identified. This may result from calls to references or other - indirect structures. - """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) - comp = comp.to_building_block() - calls = tree_analysis.find_unsecure_aggregation_in_tree(comp) - if not calls: - raise AssertionError( - 'Expected unsecure aggregation, but none were found in:\n{}'.format( - comp.compact_representation() - ) - ) - - -def assert_not_contains_unsecure_aggregation(comp): - """Asserts that `comp` contains no unsecure aggregation calls. - - Args: - comp: A `tff.Computation`, often a function annotated with - `tff.federated_computation` or `tff.tensorflow.computation`. Note that - polymorphic functions (those without the types of their arguments - explicitly specified) will not yet be `tff.Computation`s. - - Raises: - AssertionError if `comp` contains an unsecure aggregation call. - ValueError if `comp` contains a call whose target function cannot be - identified. This may result from calls to references or other - indirect structures. - """ - py_typecheck.check_type(comp, computation_impl.ConcreteComputation) - comp = comp.to_building_block() - calls = tree_analysis.find_unsecure_aggregation_in_tree(comp) - if calls: - _raise_expected_none(calls, 'unsecure') diff --git a/tensorflow_federated/python/core/test/static_assert_test.py b/tensorflow_federated/python/core/test/static_assert_test.py deleted file mode 100644 index cb486116fd..0000000000 --- a/tensorflow_federated/python/core/test/static_assert_test.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2020, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.test import static_assert - - -@federated_computation.federated_computation -def no_aggregation(): - return () - - -@federated_computation.federated_computation -def secure_aggregation(): - data_at_clients = intrinsics.federated_value(1, placements.CLIENTS) - bitwidth = 1 - return intrinsics.federated_secure_sum_bitwidth(data_at_clients, bitwidth) - - -@federated_computation.federated_computation -def unsecure_aggregation(): - data_at_clients = intrinsics.federated_value(1, placements.CLIENTS) - return intrinsics.federated_sum(data_at_clients) - - -@federated_computation.federated_computation -def secure_and_unsecure_aggregation(): - return (secure_aggregation, unsecure_aggregation) - - -class AssertContainsSecAggTest(absltest.TestCase): - - def test_fails_on_noagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_contains_secure_aggregation(no_aggregation) - - def test_passes_on_secagg(self): - static_assert.assert_contains_secure_aggregation(secure_aggregation) - - def test_fails_on_unsecagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_contains_secure_aggregation(unsecure_aggregation) - - def test_passes_on_bothagg(self): - static_assert.assert_contains_secure_aggregation( - secure_and_unsecure_aggregation - ) - - -class AssertNotContainsSecAggTest(absltest.TestCase): - - def test_passes_on_noagg(self): - static_assert.assert_not_contains_secure_aggregation(no_aggregation) - - def test_fails_on_secagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_not_contains_secure_aggregation(secure_aggregation) - - def test_passes_on_unsecagg(self): - static_assert.assert_not_contains_secure_aggregation(unsecure_aggregation) - - def test_fails_on_bothagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_not_contains_secure_aggregation( - secure_and_unsecure_aggregation - ) - - -class AssertContainsUnsecAggTest(absltest.TestCase): - - def test_fails_on_noagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_contains_unsecure_aggregation(no_aggregation) - - def test_fails_on_secagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_contains_unsecure_aggregation(secure_aggregation) - - def test_passes_on_unsecagg(self): - static_assert.assert_contains_unsecure_aggregation(unsecure_aggregation) - - def test_passes_on_bothagg(self): - static_assert.assert_contains_unsecure_aggregation( - secure_and_unsecure_aggregation - ) - - -class AssertNotContainsUnsecAggTest(absltest.TestCase): - - def test_passes_on_noagg(self): - static_assert.assert_not_contains_unsecure_aggregation(no_aggregation) - - def test_passes_on_secagg(self): - static_assert.assert_not_contains_unsecure_aggregation(secure_aggregation) - - def test_fails_on_unsecagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_not_contains_unsecure_aggregation( - unsecure_aggregation - ) - - def test_fails_on_bothagg(self): - with self.assertRaises(AssertionError): - static_assert.assert_not_contains_unsecure_aggregation( - secure_and_unsecure_aggregation - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/learning/BUILD b/tensorflow_federated/python/learning/BUILD index 92498e5d1f..db20591bee 100644 --- a/tensorflow_federated/python/learning/BUILD +++ b/tensorflow_federated/python/learning/BUILD @@ -81,14 +81,9 @@ py_test( ":model_update_aggregator", "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/core/backends/mapreduce:form_utils", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:iterative_process", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -99,8 +94,7 @@ py_library( "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/aggregators:measurements", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -111,9 +105,7 @@ py_test( ":debug_measurements", "//tensorflow_federated/python/aggregators:mean", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/algorithms/BUILD b/tensorflow_federated/python/learning/algorithms/BUILD index 66933fdfc3..30dea8071e 100644 --- a/tensorflow_federated/python/learning/algorithms/BUILD +++ b/tensorflow_federated/python/learning/algorithms/BUILD @@ -43,7 +43,6 @@ py_library( "//tensorflow_federated/python/aggregators:mean", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -58,6 +57,7 @@ py_library( "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", "//tensorflow_federated/python/learning/templates:model_delta_client_work", + "@federated_language//federated_language", ], ) @@ -69,13 +69,13 @@ py_cpu_gpu_test( deps = [ ":fed_avg", "//tensorflow_federated/python/aggregators:factory_utils", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) @@ -89,10 +89,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", @@ -108,6 +104,7 @@ py_library( "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", "//tensorflow_federated/python/learning/templates:model_delta_client_work", + "@federated_language//federated_language", ], ) @@ -118,13 +115,13 @@ py_cpu_gpu_test( shard_count = 10, deps = [ ":fed_avg_with_optimizer_schedule", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) @@ -137,7 +134,6 @@ py_library( "//tensorflow_federated/python/aggregators:mean", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -152,6 +148,7 @@ py_library( "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", "//tensorflow_federated/python/learning/templates:proximal_client_work", + "@federated_language//federated_language", ], ) @@ -164,7 +161,6 @@ py_cpu_gpu_test( ":fed_prox", "//tensorflow_federated/python/aggregators:factory_utils", "//tensorflow_federated/python/core/templates:iterative_process", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -173,6 +169,7 @@ py_cpu_gpu_test( "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", "//tensorflow_federated/python/learning/templates:distributors", + "@federated_language//federated_language", ], ) @@ -186,11 +183,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:tensor_utils", @@ -205,6 +197,7 @@ py_library( "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -222,10 +215,6 @@ py_test( "//tensorflow_federated/python/aggregators:sum_factory", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/core/templates:measured_process", @@ -237,6 +226,7 @@ py_test( "//tensorflow_federated/python/learning/optimizers:sgdm", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", + "@federated_language//federated_language", ], ) @@ -249,10 +239,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/metrics:keras_finalizer", @@ -266,6 +252,7 @@ py_library( "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:finalizers", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -276,11 +263,6 @@ py_test( ":fed_recon_eval", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/metrics:counters", "//tensorflow_federated/python/learning/models:reconstruction_model", @@ -288,6 +270,7 @@ py_test( "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -300,10 +283,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:tensor_utils", @@ -319,6 +298,7 @@ py_library( "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -330,7 +310,6 @@ py_cpu_gpu_test( deps = [ ":fed_sgd", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -338,6 +317,7 @@ py_cpu_gpu_test( "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", + "@federated_language//federated_language", ], ) @@ -349,16 +329,13 @@ py_library( "//tensorflow_federated/python/aggregators:factory_utils", "//tensorflow_federated/python/aggregators:sum_factory", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/templates:client_works", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:finalizers", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -368,8 +345,7 @@ py_test( deps = [ ":kmeans_clustering", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -384,10 +360,6 @@ py_library( "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", @@ -404,6 +376,7 @@ py_library( "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -423,13 +396,8 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", @@ -448,6 +416,7 @@ py_cpu_gpu_test( "//tensorflow_federated/python/learning/optimizers:yogi", "//tensorflow_federated/python/learning/templates:client_works", "//tensorflow_federated/python/learning/templates:distributors", + "@federated_language//federated_language", ], ) @@ -459,11 +428,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:loop_builder", @@ -476,6 +440,7 @@ py_library( "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:finalizers", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -489,11 +454,6 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -507,6 +467,7 @@ py_cpu_gpu_test( "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -519,14 +480,9 @@ py_library( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:array_shape", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:variable", + "@federated_language//federated_language", ], ) @@ -537,10 +493,10 @@ py_cpu_gpu_test( deps = [ ":personalization_eval", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning/models:keras_utils", "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:model_weights", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg.py b/tensorflow_federated/python/learning/algorithms/fed_avg.py index a5d59a9a94..7d19f720df 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg.py @@ -31,6 +31,7 @@ from collections.abc import Callable from typing import Optional, Union +import federated_language import numpy as np import tensorflow as tf @@ -39,7 +40,6 @@ from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator @@ -218,7 +218,7 @@ def initial_model_weights_fn(): else: model_update_type = model_weights_type.trainable aggregator = model_aggregator.create( - model_update_type, computation_types.TensorType(np.float32) + model_update_type, federated_language.TensorType(np.float32) ) process_signature = aggregator.next.type_signature diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_test.py b/tensorflow_federated/python/learning/algorithms/fed_avg_test.py index bdfe8acd09..abac087885 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_test.py @@ -16,9 +16,9 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language from tensorflow_federated.python.aggregators import factory_utils -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_avg @@ -135,7 +135,7 @@ def test_weighted_fed_avg_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) @@ -149,7 +149,7 @@ def test_unweighted_fed_avg_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) @@ -182,7 +182,7 @@ def test_weighted_fed_avg_with_only_secure_aggregation(self, constructor): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py index d3eb012c92..37ec82bddd 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py @@ -18,6 +18,7 @@ import functools from typing import Optional, Union +import federated_language import numpy as np import tensorflow as tf @@ -26,10 +27,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder @@ -95,7 +92,7 @@ def build_scheduled_client_work( if isinstance(model_fn, functional.FunctionalModel): whimsy_model = model_fn weights_type = tf.nest.map_structure( - lambda arr: computation_types.TensorType( + lambda arr: federated_language.TensorType( dtype=arr.dtype, shape=arr.shape ), model_weights.ModelWeights(*whimsy_model.initial_weights), @@ -117,7 +114,7 @@ def build_scheduled_client_work( f'tff.learning.optimizers.Optimizer, got {type(optimizer_fn)=}' ) element_type = tensorflow_types.to_type(whimsy_model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) if isinstance(model_fn, functional.FunctionalModel): build_client_update_fn = ( @@ -145,9 +142,9 @@ def client_update_computation(initial_model_weights, dataset, round_num): ) return client_update(optimizer, initial_model_weights, dataset) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) @tensorflow_computation.tf_computation(np.int32) @tf.function @@ -159,23 +156,25 @@ def add_one(x): def tf_learning_rate_fn(x): return learning_rate_fn(x) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): - round_num_at_clients = intrinsics.federated_broadcast(state) + round_num_at_clients = federated_language.federated_broadcast(state) # We also compute the learning rate at the server, in order to expose the # measurement to the user. - learning_rate = intrinsics.federated_map(tf_learning_rate_fn, state) + learning_rate = federated_language.federated_map(tf_learning_rate_fn, state) # TODO: b/268530457 - Determine if we can broadcast the learning rate. - client_result, model_outputs = intrinsics.federated_map( + client_result, model_outputs = federated_language.federated_map( client_update_computation, (weights, client_data, round_num_at_clients) ) - updated_state = intrinsics.federated_map(add_one, state) + updated_state = federated_language.federated_map(add_one, state) train_metrics = metrics_aggregation_fn(model_outputs) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict( train=train_metrics, client_learning_rate=learning_rate ) @@ -307,7 +306,7 @@ def initial_model_weights_fn(): model_aggregator = mean.MeanFactory() py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory) aggregator = model_aggregator.create( - model_weights_type.trainable, computation_types.TensorType(np.float32) + model_weights_type.trainable, federated_language.TensorType(np.float32) ) process_signature = aggregator.next.type_signature input_client_value_type = process_signature.parameter[1] # pytype: disable=unsupported-operands diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py index d881551901..69d63dbff9 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py @@ -16,9 +16,9 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import tensorflow as tf -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_avg_with_optimizer_schedule @@ -154,7 +154,7 @@ def test_construction_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_eval.py b/tensorflow_federated/python/learning/algorithms/fed_eval.py index 2fed401ea8..bde2eda78d 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_eval.py +++ b/tensorflow_federated/python/learning/algorithms/fed_eval.py @@ -17,17 +17,13 @@ from collections.abc import Callable, Mapping from typing import Optional, Union +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import loop_builder @@ -47,10 +43,10 @@ def _build_local_evaluation( model_fn: Callable[[], variable.VariableModel], - model_weights_type: computation_types.StructType, - batch_type: computation_types.Type, + model_weights_type: federated_language.StructType, + batch_type: federated_language.Type, loop_implementation: loop_builder.LoopImplementation, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Builds the local TFF computation for evaluation of the given model. This produces an unplaced function that evaluates a @@ -82,7 +78,7 @@ def _build_local_evaluation( """ @tensorflow_computation.tf_computation( - model_weights_type, computation_types.SequenceType(batch_type) + model_weights_type, federated_language.SequenceType(batch_type) ) @tf.function def client_eval(incoming_model_weights, dataset): @@ -122,11 +118,11 @@ def reduce_fn(num_examples, batch): def _build_functional_local_evaluation( model: functional.FunctionalModel, - model_weights_type: computation_types.StructType, + model_weights_type: federated_language.StructType, batch_type: Union[ - computation_types.StructType, computation_types.TensorType + federated_language.StructType, federated_language.TensorType ], -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates client evaluation logic for a functional model. This produces an unplaced function that evaluates a @@ -154,7 +150,7 @@ def _build_functional_local_evaluation( """ @tensorflow_computation.tf_computation( - model_weights_type, computation_types.SequenceType(batch_type) + model_weights_type, federated_language.SequenceType(batch_type) ) @tf.function def local_eval(weights, dataset): @@ -192,7 +188,7 @@ def local_eval(weights, dataset): def _build_fed_eval_client_work( model_fn: Callable[[], variable.VariableModel], metrics_aggregation_process: Optional[_AggregationProcess], - model_weights_type: computation_types.StructType, + model_weights_type: federated_language.StructType, loop_implementation: loop_builder.LoopImplementation, ) -> client_works.ClientWorkProcess: """Builds a `ClientWorkProcess` that performs model evaluation at clients.""" @@ -213,7 +209,7 @@ def _tensor_type_from_tensor_like(x): if metrics_aggregation_process is None: # TODO: b/319261270 - Avoid the need for inferring types here, if possible. metrics_finalizers = model.metric_finalizers() - unfinalized_metrics_type = computation_types.StructWithPythonType( + unfinalized_metrics_type = federated_language.StructWithPythonType( unfinalized_metrics_spec, collections.OrderedDict ) factory = sum_aggregation_factory.SumThenFinalizeFactory(metrics_finalizers) @@ -225,7 +221,7 @@ def _tensor_type_from_tensor_like(x): 'metrics_aggregation_process', ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): return metrics_aggregation_process.initialize() @@ -236,22 +232,25 @@ def init_fn(): loop_implementation=loop_implementation, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(model_weights_type, placements.CLIENTS), - computation_types.FederatedType( - computation_types.SequenceType(batch_type), placements.CLIENTS + federated_language.FederatedType( + model_weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + federated_language.SequenceType(batch_type), + federated_language.CLIENTS, ), ) def next_fn(state, model_weights, client_data): - model_outputs = intrinsics.federated_map( + model_outputs = federated_language.federated_map( client_update_computation, (model_weights, client_data) ) metrics_output = metrics_aggregation_process.next( state, model_outputs.local_outputs ) current_round_metrics, total_rounds_metrics = metrics_output.result - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict( eval=collections.OrderedDict( current_round_metrics=current_round_metrics, @@ -260,9 +259,9 @@ def next_fn(state, model_weights, client_data): ) ) # Return empty result as no model update will be performed for evaluation. - empty_client_result = intrinsics.federated_value( + empty_client_result = federated_language.federated_value( client_works.ClientResult(update=(), update_weight=()), - placements.CLIENTS, + federated_language.CLIENTS, ) return measured_process.MeasuredProcessOutput( metrics_output.state, empty_client_result, measurements @@ -274,7 +273,7 @@ def next_fn(state, model_weights, client_data): def _build_functional_fed_eval_client_work( model: functional.FunctionalModel, metrics_aggregation_process: Optional[_AggregationProcess], - model_weights_type: computation_types.StructType, + model_weights_type: federated_language.StructType, ) -> client_works.ClientWorkProcess: """Builds a `ClientWorkProcess` that performs model evaluation at clients.""" @@ -304,34 +303,37 @@ def ndarray_to_tensorspec(ndarray): ).create(unfinalized_metrics_type) ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): return metrics_aggregation_process.initialize() @tensorflow_computation.tf_computation( - model_weights_type, computation_types.SequenceType(batch_type) + model_weights_type, federated_language.SequenceType(batch_type) ) def client_update_computation(model_weights, client_data): # Switch to the tuple expected by FunctionalModel. tuple_weights = (model_weights.trainable, model_weights.non_trainable) return local_eval(tuple_weights, client_data) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(model_weights_type, placements.CLIENTS), - computation_types.FederatedType( - computation_types.SequenceType(batch_type), placements.CLIENTS + federated_language.FederatedType( + model_weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + federated_language.SequenceType(batch_type), + federated_language.CLIENTS, ), ) def next_fn(state, model_weights, client_data): - unfinalized_metrics = intrinsics.federated_map( + unfinalized_metrics = federated_language.federated_map( client_update_computation, (model_weights, client_data) ) metrics_output = metrics_aggregation_process.next( state, unfinalized_metrics ) current_round_metrics, total_rounds_metrics = metrics_output.result - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict( eval=collections.OrderedDict( current_round_metrics=current_round_metrics, @@ -340,9 +342,9 @@ def next_fn(state, model_weights, client_data): ) ) # Return empty result as no model update will be performed for evaluation. - empty_client_result = intrinsics.federated_value( + empty_client_result = federated_language.federated_value( client_works.ClientResult(update=(), update_weight=()), - placements.CLIENTS, + federated_language.CLIENTS, ) return measured_process.MeasuredProcessOutput( metrics_output.state, empty_client_result, measurements @@ -470,8 +472,9 @@ def initial_model_weights_fn(): loop_implementation=loop_implementation, ) - client_work_result_type = computation_types.FederatedType( - client_works.ClientResult(update=(), update_weight=()), placements.CLIENTS + client_work_result_type = federated_language.FederatedType( + client_works.ClientResult(update=(), update_weight=()), + federated_language.CLIENTS, ) model_update_type = client_work_result_type.member.update # pytype: disable=attribute-error model_update_weight_type = client_work_result_type.member.update_weight # pytype: disable=attribute-error diff --git a/tensorflow_federated/python/learning/algorithms/fed_eval_test.py b/tensorflow_federated/python/learning/algorithms/fed_eval_test.py index 4c425c7fa1..f6943ce43f 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_eval_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_eval_test.py @@ -18,6 +18,7 @@ from unittest import mock from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf @@ -26,11 +27,6 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.algorithms import fed_eval @@ -48,11 +44,11 @@ # Convenience aliases. -FederatedType = computation_types.FederatedType -FunctionType = computation_types.FunctionType -SequenceType = computation_types.SequenceType -StructType = computation_types.StructType -TensorType = computation_types.TensorType +FederatedType = federated_language.FederatedType +FunctionType = federated_language.FunctionType +SequenceType = federated_language.SequenceType +StructType = federated_language.StructType +TensorType = federated_language.TensorType class TestModel(variable.VariableModel): @@ -124,7 +120,7 @@ def _tensor_spec_from_tensor_like(x): return tensorflow_types.to_type((x_as_tensor.dtype, x_as_tensor.shape)) finalizer_spec = tf.nest.map_structure(_tensor_spec_from_tensor_like, metrics) - return computation_types.StructWithPythonType( + return federated_language.StructWithPythonType( finalizer_spec, collections.OrderedDict ) @@ -141,9 +137,11 @@ def create_all_zero_state(): local_unfinalized_metrics_type, ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_eval(create_all_zero_state, placements.SERVER) + return federated_language.federated_eval( + create_all_zero_state, federated_language.SERVER + ) @tensorflow_computation.tf_computation( local_unfinalized_metrics_type, local_unfinalized_metrics_type @@ -155,16 +153,18 @@ def get_max_unfinalized_metrics( tf.math.maximum, unfinalized_metrics, new_max_unfinalized_metrics ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ), ) def next_fn(state, unfinalized_metrics): - max_unfinalized_metrics = intrinsics.federated_max(unfinalized_metrics) + max_unfinalized_metrics = federated_language.federated_max( + unfinalized_metrics + ) - state = intrinsics.federated_map( + state = federated_language.federated_map( get_max_unfinalized_metrics, (state, max_unfinalized_metrics) ) @@ -177,19 +177,21 @@ def finalizer_computation(unfinalized_metrics): ) return finalized_metrics - current_round_metrics = intrinsics.federated_map( + current_round_metrics = federated_language.federated_map( finalizer_computation, max_unfinalized_metrics ) - total_rounds_metrics = intrinsics.federated_map( + total_rounds_metrics = federated_language.federated_map( finalizer_computation, state ) return measured_process.MeasuredProcessOutput( state=state, - result=intrinsics.federated_zip( + result=federated_language.federated_zip( (current_round_metrics, total_rounds_metrics) ), - measurements=intrinsics.federated_value((), placements.SERVER), + measurements=federated_language.federated_value( + (), federated_language.SERVER + ), ) return aggregation_process.AggregationProcess(init_fn, next_fn) @@ -221,7 +223,7 @@ def test_fed_eval_process_type_properties(self): eval_process = fed_eval.build_fed_eval(model_fn) self.assertIsInstance(eval_process, learning_process.LearningProcess) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( composers.LearningAlgorithmState( global_model_weights=model_weights_type, distributor=(), @@ -231,9 +233,9 @@ def test_fed_eval_process_type_properties(self): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - expected_metrics_type = computation_types.FederatedType( + expected_metrics_type = federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -245,27 +247,27 @@ def test_fed_eval_process_type_properties(self): aggregator=collections.OrderedDict(mean_value=(), mean_weight=()), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.initialize.type_signature, FunctionType(parameter=None, result=expected_state_type), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.next.type_signature, FunctionType( parameter=StructType([ ('state', expected_state_type), ( 'client_data', - computation_types.FederatedType( + federated_language.FederatedType( SequenceType( StructType([( 'temp', TensorType(dtype=np.float32, shape=[None]), )]) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), ]), @@ -274,13 +276,13 @@ def test_fed_eval_process_type_properties(self): ), ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.get_model_weights.type_signature, FunctionType( parameter=expected_state_type.member, result=model_weights_type ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.set_model_weights.type_signature, FunctionType( parameter=StructType([ @@ -344,35 +346,36 @@ def test_fed_eval_with_model_distributor(self): model_weights_type = model_weights_lib.weights_type_from_model(TestModel) def test_distributor(): - @federated_computation.federated_computation() + + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType( - model_weights_type, placements.SERVER + federated_language.FederatedType( + model_weights_type, federated_language.SERVER ), ) def next_fn(state, value): return measured_process.MeasuredProcessOutput( state, - intrinsics.federated_broadcast(value), - intrinsics.federated_value((), placements.SERVER), + federated_language.federated_broadcast(value), + federated_language.federated_value((), federated_language.SERVER), ) return distributors.DistributionProcess(init_fn, next_fn) eval_process = fed_eval.build_fed_eval(TestModel, test_distributor()) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.initialize.type_signature.result.member.distributor, test_distributor().initialize.type_signature.result.member, ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.next.type_signature.result.state.member.distributor, test_distributor().next.type_signature.result.state.member, ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( eval_process.next.type_signature.result.metrics.member.distributor, test_distributor().next.type_signature.result.measurements.member, ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_prox.py b/tensorflow_federated/python/learning/algorithms/fed_prox.py index 034555b5e5..acc38b9fe7 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_prox.py +++ b/tensorflow_federated/python/learning/algorithms/fed_prox.py @@ -24,6 +24,7 @@ from typing import Optional, Union from absl import logging +import federated_language import numpy as np import tensorflow as tf @@ -32,7 +33,6 @@ from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator @@ -207,7 +207,7 @@ def initial_model_weights_fn(): model_aggregator = mean.MeanFactory() py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory) aggregator = model_aggregator.create( - model_update_type, computation_types.TensorType(np.float32) + model_update_type, federated_language.TensorType(np.float32) ) process_signature = aggregator.next.type_signature input_client_value_type = process_signature.parameter[1] # pytype: disable=unsupported-operands diff --git a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py index 0ada68f128..e2dda7a62a 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py @@ -16,10 +16,10 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language from tensorflow_federated.python.aggregators import factory_utils from tensorflow_federated.python.core.templates import iterative_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_prox @@ -171,7 +171,7 @@ def test_weighted_fed_prox_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) @@ -186,7 +186,7 @@ def test_unweighted_fed_prox_with_only_secure_aggregation(self): ), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_recon.py b/tensorflow_federated/python/learning/algorithms/fed_recon.py index 159acf388c..e677d79904 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_recon.py +++ b/tensorflow_federated/python/learning/algorithms/fed_recon.py @@ -45,6 +45,7 @@ from collections.abc import Callable from typing import Any, Optional, Union +import federated_language import numpy as np import tensorflow as tf @@ -54,11 +55,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process as measured_process_lib from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import tensor_utils @@ -97,8 +93,8 @@ def _build_reconstruction_client_work( dataset_split_fn: reconstruction_model.ReconstructionDatasetSplitFn, client_weighting: client_weight_lib.ClientWeightType, metrics_aggregator: Callable[ - [MetricFinalizersType, computation_types.StructWithPythonType], - computation_base.Computation, + [MetricFinalizersType, federated_language.StructWithPythonType], + federated_language.framework.Computation, ], ) -> client_works.ClientWorkProcess: # pylint: enable=g-bare-generic @@ -147,12 +143,12 @@ def _build_reconstruction_client_work( model_for_metadata ) element_type = tensorflow_types.to_type(model_for_metadata.input_spec) - dataset_type = computation_types.SequenceType(element_type) + dataset_type = federated_language.SequenceType(element_type) - @federated_computation.federated_computation + @federated_language.federated_computation def initialize(): # FedRecon client work is stateless (empty tuple). - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) # Metric finalizer functions that will be populated while tracing # `client_update` and used later in the federated computation. @@ -337,26 +333,30 @@ def initial_state_train_reduce(): unfinalized_metrics, ) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType(model_weights_type, placements.CLIENTS), - computation_types.FederatedType(dataset_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType( + model_weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + dataset_type, federated_language.CLIENTS + ), ) def next_fn(state, incoming_model_weights, client_datasets): del state # Unused. - client_result, unfinalized_metrics = intrinsics.federated_map( + client_result, unfinalized_metrics = federated_language.federated_map( client_update, (incoming_model_weights, client_datasets) ) metrics_aggregation_computation = metrics_aggregator( metric_finalizers, unfinalized_metrics.type_signature.member ) - finalized_metrics = intrinsics.federated_zip( + finalized_metrics = federated_language.federated_zip( collections.OrderedDict( train=metrics_aggregation_computation(unfinalized_metrics) ) ) return measured_process_lib.MeasuredProcessOutput( - state=intrinsics.federated_value((), placements.SERVER), + state=federated_language.federated_value((), federated_language.SERVER), result=client_result, measurements=finalized_metrics, ) @@ -395,8 +395,8 @@ def build_fed_recon( model_aggregator_factory: Optional[AggregationFactory] = None, metrics_aggregator: Optional[ Callable[ - [MetricFinalizersType, computation_types.StructWithPythonType], - computation_base.Computation, + [MetricFinalizersType, federated_language.StructWithPythonType], + federated_language.framework.Computation, ] ] = metrics_aggregators.sum_then_finalize, ) -> learning_process.LearningProcess: @@ -510,7 +510,7 @@ def build_initial_model_weights(): model_aggregator_factory, factory.WeightedAggregationFactory ) model_aggregator = model_aggregator_factory.create( - model_weights_type.trainable, computation_types.TensorType(np.float32) + model_weights_type.trainable, federated_language.TensorType(np.float32) ) if dataset_split_fn is None: diff --git a/tensorflow_federated/python/learning/algorithms/fed_recon_eval.py b/tensorflow_federated/python/learning/algorithms/fed_recon_eval.py index ae6161f6e6..f83cb34193 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_recon_eval.py +++ b/tensorflow_federated/python/learning/algorithms/fed_recon_eval.py @@ -28,16 +28,13 @@ import collections from typing import Any, Optional +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process as measured_process_lib from tensorflow_federated.python.learning.algorithms import fed_recon @@ -156,7 +153,7 @@ def build_initial_model_weights(): ) model_weights_type = build_initial_model_weights.type_signature.result - dataset_type = computation_types.SequenceType(batch_type) + dataset_type = federated_language.SequenceType(batch_type) if model_distributor is None: model_distributor = distributors.build_broadcast_process(model_weights_type) @@ -306,24 +303,28 @@ def initial_state_reconstruction_reduce(): 'metrics_aggregation_process', ) - @federated_computation.federated_computation + @federated_language.federated_computation def client_initialize(): return metrics_aggregation_process.initialize() - @federated_computation.federated_computation( + @federated_language.federated_computation( client_initialize.type_signature.result, - computation_types.FederatedType(model_weights_type, placements.CLIENTS), - computation_types.FederatedType(dataset_type, placements.CLIENTS), + federated_language.FederatedType( + model_weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType( + dataset_type, federated_language.CLIENTS + ), ) def client_work(state, model_weights, client_dataset): - unfinalized_metrics = intrinsics.federated_map( + unfinalized_metrics = federated_language.federated_map( client_computation, (model_weights, client_dataset) ) metrics_output = metrics_aggregation_process.next( state, unfinalized_metrics ) current_round_metrics, total_rounds_metrics = metrics_output.result - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict( eval=collections.OrderedDict( current_round_metrics=current_round_metrics, @@ -332,9 +333,9 @@ def client_work(state, model_weights, client_dataset): ) ) # Return empty result as no model update will be performed for evaluation. - empty_client_result = intrinsics.federated_value( + empty_client_result = federated_language.federated_value( client_works.ClientResult(update=(), update_weight=()), - placements.CLIENTS, + federated_language.CLIENTS, ) return measured_process_lib.MeasuredProcessOutput( metrics_output.state, @@ -348,8 +349,9 @@ def client_work(state, model_weights, client_dataset): # The evaluation will *not* send model updates back, only metrics; so the type # is simply an empty tuple. - empty_client_work_result_type = computation_types.FederatedType( - client_works.ClientResult(update=(), update_weight=()), placements.CLIENTS + empty_client_work_result_type = federated_language.FederatedType( + client_works.ClientResult(update=(), update_weight=()), + federated_language.CLIENTS, ) empty_model_update_type = empty_client_work_result_type.member.update # pytype: disable=attribute-error empty_model_update_weight_type = ( diff --git a/tensorflow_federated/python/learning/algorithms/fed_recon_eval_test.py b/tensorflow_federated/python/learning/algorithms/fed_recon_eval_test.py index 19efd5022f..0f09b23fa0 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_recon_eval_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_recon_eval_test.py @@ -18,16 +18,12 @@ from absl.testing import absltest from absl.testing import parameterized import attrs +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import measured_process as measured_process_lib from tensorflow_federated.python.learning.algorithms import fed_recon_eval from tensorflow_federated.python.learning.metrics import counters @@ -39,9 +35,9 @@ # Convenience aliases. -FunctionType = computation_types.FunctionType -SequenceType = computation_types.SequenceType -TensorType = computation_types.TensorType +FunctionType = federated_language.FunctionType +SequenceType = federated_language.SequenceType +TensorType = federated_language.TensorType LearningAlgorithmState = composers.LearningAlgorithmState LearningProcessOutput = learning_process_lib.LearningProcessOutput @@ -206,7 +202,7 @@ def metrics_fn(): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=(), @@ -223,26 +219,26 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -264,7 +260,7 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), @@ -310,7 +306,7 @@ def metrics_fn(): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=(), @@ -327,26 +323,26 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -368,7 +364,7 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), @@ -413,7 +409,7 @@ def metrics_fn(): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=(), @@ -430,26 +426,26 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -471,7 +467,7 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), @@ -510,7 +506,7 @@ def metrics_fn(): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=(), @@ -527,26 +523,26 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -568,7 +564,7 @@ def metrics_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), @@ -615,7 +611,7 @@ def dataset_split_fn(client_dataset): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=(), @@ -632,26 +628,26 @@ def dataset_split_fn(client_dataset): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -673,7 +669,7 @@ def dataset_split_fn(client_dataset): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), @@ -721,7 +717,7 @@ def loss_fn(): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=(), @@ -736,26 +732,26 @@ def loss_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=(), client_work=collections.OrderedDict( @@ -773,7 +769,7 @@ def loss_fn(): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), @@ -813,26 +809,28 @@ def build_custom_distributor( ) -> distributors.DistributionProcess: """Builds a `DistributionProcess` that wraps `tff.federated_broadcast`.""" - @federated_computation.federated_computation() + @federated_language.federated_computation() def test_server_initialization(): # Count the number of calls. - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - model_weights_type, placements.SERVER + @federated_language.federated_computation( + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + model_weights_type, federated_language.SERVER ), ) def stateful_broadcast(state, value): - test_metrics = intrinsics.federated_value(3.0, placements.SERVER) - new_state = intrinsics.federated_map( + test_metrics = federated_language.federated_value( + 3.0, federated_language.SERVER + ) + new_state = federated_language.federated_map( tensorflow_computation.tf_computation(lambda x: x + 1), state, ) return measured_process_lib.MeasuredProcessOutput( state=new_state, - result=intrinsics.federated_broadcast(value), + result=federated_language.federated_broadcast(value), measurements=test_metrics, ) @@ -851,7 +849,7 @@ def stateful_broadcast(state, value): global_weights_type = reconstruction_model.global_weights_type_from_model( model_fn() ) - state_type = computation_types.FederatedType( + state_type = federated_language.FederatedType( LearningAlgorithmState( global_model_weights=global_weights_type, distributor=np.int32, @@ -868,26 +866,26 @@ def stateful_broadcast(state, value): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ) - type_test_utils.assert_types_identical( + federated_language.framework.assert_types_identical( evaluate.next.type_signature, FunctionType( parameter=collections.OrderedDict( state=state_type, - client_data=computation_types.FederatedType( + client_data=federated_language.FederatedType( SequenceType( collections.OrderedDict( x=TensorType(np.float32, [None, 1]), y=TensorType(np.float32, [None, 1]), ) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ), result=LearningProcessOutput( state=state_type, - metrics=computation_types.FederatedType( + metrics=federated_language.FederatedType( collections.OrderedDict( distributor=np.float32, client_work=collections.OrderedDict( @@ -909,7 +907,7 @@ def stateful_broadcast(state, value): ), finalizer=(), ), - placements.SERVER, + federated_language.SERVER, ), ), ), diff --git a/tensorflow_federated/python/learning/algorithms/fed_recon_test.py b/tensorflow_federated/python/learning/algorithms/fed_recon_test.py index 71d4ef161d..732aec36b5 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_recon_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_recon_test.py @@ -18,6 +18,7 @@ from absl.testing import parameterized import attrs +import federated_language import numpy as np import tensorflow as tf import tensorflow_privacy as tfp @@ -29,10 +30,6 @@ from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process as aggregation_process_lib from tensorflow_federated.python.core.templates import iterative_process as iterative_process_lib from tensorflow_federated.python.core.templates import measured_process as measured_process_lib @@ -210,11 +207,11 @@ def __init__(self, dp_sum_factory): self._clear_sum = sum_factory.SumFactory() def create( - self, value_type: computation_types.Type + self, value_type: federated_language.Type ) -> aggregation_process_lib.AggregationProcess: self._dp_sum_process = self._dp_sum.create(value_type) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init(): # Invoke here to instantiate anything we need return self._dp_sum_process.initialize() @@ -224,17 +221,23 @@ def div(x, y): # Opaque shape manipulations return [tf.squeeze(tf.math.divide_no_nan(x, tf.cast(y, tf.float32)), 0)] - @federated_computation.federated_computation( + @federated_language.federated_computation( init.type_signature.result, - computation_types.FederatedType(value_type, placements.CLIENTS), + federated_language.FederatedType( + value_type, federated_language.CLIENTS + ), ) def next_fn(state, value): - one_at_clients = intrinsics.federated_value(1, placements.CLIENTS) + one_at_clients = federated_language.federated_value( + 1, federated_language.CLIENTS + ) dp_sum = self._dp_sum_process.next(state, value) - summed_one = intrinsics.federated_sum(one_at_clients) + summed_one = federated_language.federated_sum(one_at_clients) return measured_process_lib.MeasuredProcessOutput( state=dp_sum.state, - result=intrinsics.federated_map(div, (dp_sum.result, summed_one)), + result=federated_language.federated_map( + div, (dp_sum.result, summed_one) + ), measurements=dp_sum.measurements, ) @@ -978,21 +981,27 @@ def build_custom_stateful_distributor( ) -> distributors.DistributionProcess: """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`.""" - @federated_computation.federated_computation() + @federated_language.federated_computation() def test_server_initialization(): - return intrinsics.federated_value(2.0, placements.SERVER) + return federated_language.federated_value( + 2.0, federated_language.SERVER + ) - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.SERVER), - computation_types.FederatedType( - model_weights_type, placements.SERVER + @federated_language.federated_computation( + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), + federated_language.FederatedType( + model_weights_type, federated_language.SERVER ), ) def stateful_broadcast(state, value): - empty_metrics = intrinsics.federated_value(1.0, placements.SERVER) + empty_metrics = federated_language.federated_value( + 1.0, federated_language.SERVER + ) return measured_process_lib.MeasuredProcessOutput( state=state, - result=intrinsics.federated_broadcast(value), + result=federated_language.federated_broadcast(value), measurements=empty_metrics, ) diff --git a/tensorflow_federated/python/learning/algorithms/fed_sgd.py b/tensorflow_federated/python/learning/algorithms/fed_sgd.py index 9efdac95b8..fbce8de61e 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_sgd.py +++ b/tensorflow_federated/python/learning/algorithms/fed_sgd.py @@ -25,6 +25,7 @@ from collections.abc import Callable, Mapping from typing import Any, Optional, Union +import federated_language import numpy as np import tensorflow as tf @@ -33,10 +34,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import tensor_utils @@ -169,12 +166,12 @@ def _build_fed_sgd_client_work( model.metric_finalizers(), ) element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) weights_type = model_weights_lib.weights_type_from_model(model) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) @tensorflow_computation.tf_computation(weights_type, data_type) def client_update_computation(initial_model_weights, dataset): @@ -183,17 +180,19 @@ def client_update_computation(initial_model_weights, dataset): ) return client_update(initial_model_weights, dataset) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, model_weights, client_data): - client_result, model_outputs = intrinsics.federated_map( + client_result, model_outputs = federated_language.federated_map( client_update_computation, (model_weights, client_data) ) train_metrics = metrics_aggregation_fn(model_outputs) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( @@ -334,7 +333,7 @@ def _build_functional_fed_sgd_client_work( """ py_typecheck.check_type(model, functional.FunctionalModel) element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) def ndarray_to_tensorspec(ndarray): return tf.TensorSpec(shape=ndarray.shape, dtype=ndarray.dtype) @@ -347,9 +346,9 @@ def ndarray_to_tensorspec(ndarray): ) weights_type = tensorflow_types.to_type(weights_spec) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) @tensorflow_computation.tf_computation(weights_type, data_type) def client_update_computation(initial_model_weights, dataset): @@ -358,20 +357,22 @@ def client_update_computation(initial_model_weights, dataset): ) return client_update(initial_model_weights, dataset) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, model_weights, client_data): - client_result, unfinalized_metrics = intrinsics.federated_map( + client_result, unfinalized_metrics = federated_language.federated_map( client_update_computation, (model_weights, client_data) ) metrics_aggregation_fn = metrics_aggregator( model.finalize_metrics, unfinalized_metrics.type_signature.member ) train_metrics = metrics_aggregation_fn(unfinalized_metrics) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( @@ -495,7 +496,7 @@ def initial_model_weights_fn(): if model_aggregator is None: model_aggregator = mean.MeanFactory() aggregator = model_aggregator.create( - model_update_type, computation_types.TensorType(np.float32) + model_update_type, federated_language.TensorType(np.float32) ) if metrics_aggregator is None: diff --git a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py index e2ad6b4ee1..8162b10b66 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py @@ -16,11 +16,11 @@ from unittest import mock from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_sgd @@ -184,7 +184,7 @@ def test_no_unsecure_aggregation_with_secure_aggregator(self): model_aggregator=model_update_aggregator.secure_aggregator(), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) @@ -283,7 +283,7 @@ def test_no_unsecure_aggregation_with_secure_aggregator(self): model_aggregator=model_update_aggregator.secure_aggregator(), metrics_aggregator=aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) diff --git a/tensorflow_federated/python/learning/algorithms/kmeans_clustering.py b/tensorflow_federated/python/learning/algorithms/kmeans_clustering.py index c737e366e5..21582e5f9c 100644 --- a/tensorflow_federated/python/learning/algorithms/kmeans_clustering.py +++ b/tensorflow_federated/python/learning/algorithms/kmeans_clustering.py @@ -22,6 +22,7 @@ import collections from typing import Optional +import federated_language import numpy as np import tensorflow as tf @@ -29,10 +30,6 @@ from tensorflow_federated.python.aggregators import factory_utils from tensorflow_federated.python.aggregators import sum_factory from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.learning.templates import composers @@ -123,29 +120,31 @@ def reduce_fn(state, point): def _build_kmeans_client_work( - centroids_type: computation_types.TensorType, - data_type: computation_types.SequenceType, + centroids_type: federated_language.TensorType, + data_type: federated_language.SequenceType, ): """Creates a `tff.learning.templates.ClientWorkProcess` for k-means.""" - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) @tensorflow_computation.tf_computation(centroids_type, data_type) def client_update(centroids, client_data): return _compute_kmeans_step(centroids, client_data) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(centroids_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + centroids_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, cluster_centers, client_data): - client_result, stat_output = intrinsics.federated_map( + client_result, stat_output = federated_language.federated_map( client_update, (cluster_centers, client_data) ) - stat_metrics = intrinsics.federated_sum(stat_output) + stat_metrics = federated_language.federated_sum(stat_output) return measured_process.MeasuredProcessOutput( state, client_result, stat_metrics ) @@ -196,7 +195,7 @@ def _update_centroids( def _build_kmeans_finalizer( - centroids_type: computation_types.Type, num_centroids: int + centroids_type: federated_language.Type, num_centroids: int ): """Builds a `tff.learning.templates.FinalizerProcess` for k-means.""" @@ -204,9 +203,11 @@ def _build_kmeans_finalizer( def initialize_weights(): return tf.ones((num_centroids,), dtype=_WEIGHT_DTYPE) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_eval(initialize_weights, placements.SERVER) + return federated_language.federated_eval( + initialize_weights, federated_language.SERVER + ) weights_type = initialize_weights.type_signature.result @@ -220,23 +221,27 @@ def server_update_tf( current_centroids, current_weights, new_centroid_sums, new_weights ) - summed_updates_type = computation_types.FederatedType( - computation_types.to_type((centroids_type, weights_type)), - placements.SERVER, + summed_updates_type = federated_language.FederatedType( + federated_language.to_type((centroids_type, weights_type)), + federated_language.SERVER, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(centroids_type, placements.SERVER), + federated_language.FederatedType( + centroids_type, federated_language.SERVER + ), summed_updates_type, ) def next_fn(state, current_centroids, summed_updates): new_centroid_sums, new_weights = summed_updates - updated_centroids, updated_weights = intrinsics.federated_map( + updated_centroids, updated_weights = federated_language.federated_map( server_update_tf, (current_centroids, state, new_centroid_sums, new_weights), ) - empty_measurements = intrinsics.federated_value((), placements.SERVER) + empty_measurements = federated_language.federated_value( + (), federated_language.SERVER + ) return measured_process.MeasuredProcessOutput( updated_weights, updated_centroids, empty_measurements ) @@ -325,12 +330,12 @@ def initialize_centers(): centroids_shape, random_seed, dtype=_POINT_DTYPE ) - centroids_type = computation_types.TensorType(_POINT_DTYPE, centroids_shape) - weights_type = computation_types.TensorType( + centroids_type = federated_language.TensorType(_POINT_DTYPE, centroids_shape) + weights_type = federated_language.TensorType( _WEIGHT_DTYPE, shape=(num_clusters,) ) - point_type = computation_types.TensorType(_POINT_DTYPE, shape=data_shape) - data_type = computation_types.SequenceType(point_type) + point_type = federated_language.TensorType(_POINT_DTYPE, shape=data_shape) + data_type = federated_language.SequenceType(point_type) if distributor is None: distributor = distributors.build_broadcast_process(centroids_type) @@ -342,10 +347,10 @@ def initialize_centers(): # We wrap the sum factory as a weighted aggregator for compatibility with # the learning process composer. weighted_aggregator = factory_utils.as_weighted_aggregator(sum_aggregator) - value_type = computation_types.to_type((centroids_type, weights_type)) + value_type = federated_language.to_type((centroids_type, weights_type)) aggregator = weighted_aggregator.create( value_type, - computation_types.to_type(()), # pytype: disable=wrong-arg-types + federated_language.to_type(()), # pytype: disable=wrong-arg-types ) finalizer = _build_kmeans_finalizer(centroids_type, num_clusters) diff --git a/tensorflow_federated/python/learning/algorithms/kmeans_clustering_test.py b/tensorflow_federated/python/learning/algorithms/kmeans_clustering_test.py index 1639d191ba..3bc5b38ca5 100644 --- a/tensorflow_federated/python/learning/algorithms/kmeans_clustering_test.py +++ b/tensorflow_federated/python/learning/algorithms/kmeans_clustering_test.py @@ -15,12 +15,11 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.algorithms import kmeans_clustering _WEIGHT_DTYPE = kmeans_clustering._WEIGHT_DTYPE @@ -142,11 +141,13 @@ def test_build_kmeans_client_work_with_different_shapes(self, shape): point_dtype = np.float32 num_clusters = 5 centroids_shape = (num_clusters,) + shape - centroids_type = computation_types.TensorType(point_dtype, centroids_shape) - point_type = computation_types.TensorType(point_dtype, shape) - data_type = computation_types.SequenceType(point_type) - weight_type = computation_types.TensorType(_WEIGHT_DTYPE, (num_clusters,)) - empty_server_type = computation_types.FederatedType((), placements.SERVER) + centroids_type = federated_language.TensorType(point_dtype, centroids_shape) + point_type = federated_language.TensorType(point_dtype, shape) + data_type = federated_language.SequenceType(point_type) + weight_type = federated_language.TensorType(_WEIGHT_DTYPE, (num_clusters,)) + empty_server_type = federated_language.FederatedType( + (), federated_language.SERVER + ) client_work = kmeans_clustering._build_kmeans_client_work( centroids_type, data_type @@ -155,19 +156,21 @@ def test_build_kmeans_client_work_with_different_shapes(self, shape): next_type = client_work.next.type_signature next_type.parameter[0].check_equivalent_to(empty_server_type) next_type.parameter[1].check_equivalent_to( - computation_types.FederatedType(centroids_type, placements.CLIENTS) + federated_language.FederatedType( + centroids_type, federated_language.CLIENTS + ) ) next_type.parameter[2].check_equivalent_to( - computation_types.FederatedType(data_type, placements.CLIENTS) + federated_language.FederatedType(data_type, federated_language.CLIENTS) ) next_type.result[0].check_equivalent_to(empty_server_type) next_type.result[1].member.update.check_equivalent_to( - computation_types.to_type((centroids_type, weight_type)) + federated_language.to_type((centroids_type, weight_type)) ) - expected_measurements_type = computation_types.to_type( + expected_measurements_type = federated_language.to_type( collections.OrderedDict( - num_examples=computation_types.TensorType(_WEIGHT_DTYPE) + num_examples=federated_language.TensorType(_WEIGHT_DTYPE) ) ) next_type.result[2].member.check_equivalent_to(expected_measurements_type) @@ -182,11 +185,13 @@ def test_build_kmeans_client_work_with_different_dtypes(self, point_dtype): shape = (3, 2) num_clusters = 5 centroids_shape = (num_clusters,) + shape - centroids_type = computation_types.TensorType(point_dtype, centroids_shape) - point_type = computation_types.TensorType(point_dtype, shape) - data_type = computation_types.SequenceType(point_type) - weight_type = computation_types.TensorType(_WEIGHT_DTYPE, (num_clusters,)) - empty_server_type = computation_types.FederatedType((), placements.SERVER) + centroids_type = federated_language.TensorType(point_dtype, centroids_shape) + point_type = federated_language.TensorType(point_dtype, shape) + data_type = federated_language.SequenceType(point_type) + weight_type = federated_language.TensorType(_WEIGHT_DTYPE, (num_clusters,)) + empty_server_type = federated_language.FederatedType( + (), federated_language.SERVER + ) client_work = kmeans_clustering._build_kmeans_client_work( centroids_type, data_type @@ -195,19 +200,21 @@ def test_build_kmeans_client_work_with_different_dtypes(self, point_dtype): next_type = client_work.next.type_signature next_type.parameter[0].check_equivalent_to(empty_server_type) next_type.parameter[1].check_equivalent_to( - computation_types.FederatedType(centroids_type, placements.CLIENTS) + federated_language.FederatedType( + centroids_type, federated_language.CLIENTS + ) ) next_type.parameter[2].check_equivalent_to( - computation_types.FederatedType(data_type, placements.CLIENTS) + federated_language.FederatedType(data_type, federated_language.CLIENTS) ) next_type.result[0].check_equivalent_to(empty_server_type) next_type.result[1].member.update.check_equivalent_to( - computation_types.to_type((centroids_type, weight_type)) + federated_language.to_type((centroids_type, weight_type)) ) - expected_measurements_type = computation_types.to_type( + expected_measurements_type = federated_language.to_type( collections.OrderedDict( - num_examples=computation_types.TensorType(_WEIGHT_DTYPE) + num_examples=federated_language.TensorType(_WEIGHT_DTYPE) ) ) next_type.result[2].member.check_equivalent_to(expected_measurements_type) diff --git a/tensorflow_federated/python/learning/algorithms/mime.py b/tensorflow_federated/python/learning/algorithms/mime.py index 239365f94b..4ccb93fcd7 100644 --- a/tensorflow_federated/python/learning/algorithms/mime.py +++ b/tensorflow_federated/python/learning/algorithms/mime.py @@ -26,6 +26,7 @@ from collections.abc import Callable, Mapping, Sequence from typing import Any, Optional, Union +import federated_language import numpy as np import tensorflow as tf @@ -36,10 +37,6 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder @@ -241,25 +238,25 @@ def _build_mime_lite_client_work( model.metric_finalizers(), ) element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) weights_type = model_weights_lib.weights_type_from_model(model) weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(weights_type) full_gradient_aggregator = full_gradient_aggregator.create( - weights_type.trainable, computation_types.TensorType(np.float32) + weights_type.trainable, federated_language.TensorType(np.float32) ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): specs = weight_tensor_specs.trainable - optimizer_state = intrinsics.federated_eval( + optimizer_state = federated_language.federated_eval( tensorflow_computation.tf_computation( lambda: optimizer.initialize(specs) ), - placements.SERVER, + federated_language.SERVER, ) aggregator_state = full_gradient_aggregator.initialize() - return intrinsics.federated_zip((optimizer_state, aggregator_state)) + return federated_language.federated_zip((optimizer_state, aggregator_state)) client_update_fn = _build_client_update_fn_for_mime_lite( model_fn, @@ -278,30 +275,36 @@ def update_optimizer_state(state, aggregate_gradient): updated_state, _ = optimizer.next(state, whimsy_weights, aggregate_gradient) return updated_state - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): optimizer_state, aggregator_state = state - optimizer_state_at_clients = intrinsics.federated_broadcast(optimizer_state) - client_result, model_outputs, full_gradient = intrinsics.federated_map( - client_update_fn, (optimizer_state_at_clients, weights, client_data) + optimizer_state_at_clients = federated_language.federated_broadcast( + optimizer_state + ) + client_result, model_outputs, full_gradient = ( + federated_language.federated_map( + client_update_fn, (optimizer_state_at_clients, weights, client_data) + ) ) full_gradient_agg_output = full_gradient_aggregator.next( aggregator_state, full_gradient, client_result.update_weight ) - updated_optimizer_state = intrinsics.federated_map( + updated_optimizer_state = federated_language.federated_map( update_optimizer_state, (optimizer_state, full_gradient_agg_output.result), ) - new_state = intrinsics.federated_zip( + new_state = federated_language.federated_zip( (updated_optimizer_state, full_gradient_agg_output.state) ) train_metrics = metrics_aggregation_fn(model_outputs) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( @@ -519,7 +522,7 @@ def _build_mime_lite_functional_client_work( metrics_aggregator = metric_aggregator.sum_then_finalize element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) def ndarray_to_tensorspec(ndarray): return tf.TensorSpec( @@ -536,19 +539,19 @@ def ndarray_to_tensorspec(ndarray): full_gradient_aggregator = full_gradient_aggregator.create( weights_type.trainable, # pytype: disable=attribute-error - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - optimizer_state = intrinsics.federated_eval( + optimizer_state = federated_language.federated_eval( tensorflow_computation.tf_computation( lambda: optimizer.initialize(weight_tensor_specs.trainable) ), - placements.SERVER, + federated_language.SERVER, ) aggregator_state = full_gradient_aggregator.initialize() - return intrinsics.federated_zip((optimizer_state, aggregator_state)) + return federated_language.federated_zip((optimizer_state, aggregator_state)) client_update_fn = _build_functional_client_update_fn_for_mime_lite( model=model, @@ -570,28 +573,32 @@ def update_optimizer_state(state, aggregate_gradient): updated_state, _ = optimizer.next(state, whimsy_weights, aggregate_gradient) return updated_state - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): optimizer_state, aggregator_state = state - optimizer_state_at_clients = intrinsics.federated_broadcast(optimizer_state) + optimizer_state_at_clients = federated_language.federated_broadcast( + optimizer_state + ) client_result, unfinalized_metrics, full_gradient = ( - intrinsics.federated_map( + federated_language.federated_map( client_update_fn, (optimizer_state_at_clients, weights, client_data) ) ) full_gradient_agg_output = full_gradient_aggregator.next( aggregator_state, full_gradient, client_result.update_weight ) - updated_optimizer_state = intrinsics.federated_map( + updated_optimizer_state = federated_language.federated_map( update_optimizer_state, (optimizer_state, full_gradient_agg_output.result), ) - new_state = intrinsics.federated_zip( + new_state = federated_language.federated_zip( (updated_optimizer_state, full_gradient_agg_output.state) ) @@ -599,7 +606,7 @@ def next_fn(state, weights, client_data): model.finalize_metrics, unfinalized_metrics.type_signature.member ) train_metrics = metrics_aggregation_fn(unfinalized_metrics) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( @@ -694,15 +701,16 @@ def initialize_learning_rate(mime_state): mime_state[0][optimizer_base.LEARNING_RATE_KEY] = learning_rate_fn(0) return mime_state - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): initial_state = client_work.initialize() - updated_state = intrinsics.federated_map( + updated_state = federated_language.federated_map( initialize_learning_rate, initial_state ) - return intrinsics.federated_zip( - (intrinsics.federated_value(0, placements.SERVER), updated_state) - ) + return federated_language.federated_zip(( + federated_language.federated_value(0, federated_language.SERVER), + updated_state, + )) state_type = init_fn.type_signature.result.member @@ -715,17 +723,21 @@ def update_state(state): mime_state[0][optimizer_base.LEARNING_RATE_KEY] = updated_learning_rate return (updated_round_num, mime_state) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): round_num, mime_state = state output = client_work.next(mime_state, weights, client_data) updated_mime_state = output.state - outer_state = intrinsics.federated_zip((round_num, updated_mime_state)) - updated_state = intrinsics.federated_map(update_state, outer_state) + outer_state = federated_language.federated_zip( + (round_num, updated_mime_state) + ) + updated_state = federated_language.federated_map(update_state, outer_state) return measured_process.MeasuredProcessOutput( updated_state, output.result, output.measurements ) @@ -881,7 +893,7 @@ def initial_model_weights_fn(): py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory) model_update_type = model_weights_type.trainable model_aggregator = model_aggregator.create( - model_update_type, computation_types.TensorType(np.float32) + model_update_type, federated_language.TensorType(np.float32) ) if full_gradient_aggregator is None: full_gradient_aggregator = mean.MeanFactory() @@ -1202,7 +1214,7 @@ def initial_model_weights_fn(): py_typecheck.check_type(model_aggregator, factory.WeightedAggregationFactory) model_update_type = model_weights_type.trainable model_aggregator = model_aggregator.create( - model_update_type, computation_types.TensorType(np.float32) + model_update_type, federated_language.TensorType(np.float32) ) if full_gradient_aggregator is None: full_gradient_aggregator = mean.MeanFactory() diff --git a/tensorflow_federated/python/learning/algorithms/mime_test.py b/tensorflow_federated/python/learning/algorithms/mime_test.py index fe699abee3..ec4cea97a3 100644 --- a/tensorflow_federated/python/learning/algorithms/mime_test.py +++ b/tensorflow_federated/python/learning/algorithms/mime_test.py @@ -16,6 +16,7 @@ from unittest import mock from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -25,13 +26,8 @@ from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator @@ -71,24 +67,27 @@ def test_type_properties(self, weighting): self.assertIsInstance(client_work_process, client_works.ClientWorkProcess) mw_type = model_weights.ModelWeights( - trainable=computation_types.to_type([(np.float32, (2, 1)), np.float32]), - non_trainable=computation_types.to_type([np.float32]), + trainable=federated_language.to_type( + [(np.float32, (2, 1)), np.float32] + ), + non_trainable=federated_language.to_type([np.float32]), ) - expected_param_model_weights_type = computation_types.FederatedType( - mw_type, placements.CLIENTS + expected_param_model_weights_type = federated_language.FederatedType( + mw_type, federated_language.CLIENTS ) element_type = tensorflow_types.to_type(model_fn().input_spec) - expected_param_data_type = computation_types.FederatedType( - computation_types.SequenceType(element_type), placements.CLIENTS + expected_param_data_type = federated_language.FederatedType( + federated_language.SequenceType(element_type), + federated_language.CLIENTS, ) - expected_result_type = computation_types.FederatedType( + expected_result_type = federated_language.FederatedType( client_works.ClientResult( update=mw_type.trainable, - update_weight=computation_types.TensorType(np.float32), + update_weight=federated_language.TensorType(np.float32), ), - placements.CLIENTS, + federated_language.CLIENTS, ) - expected_optimizer_state_type = computation_types.StructWithPythonType( + expected_optimizer_state_type = federated_language.StructWithPythonType( collections.OrderedDict( learning_rate=np.float32, momentum=np.float32, @@ -96,30 +95,30 @@ def test_type_properties(self, weighting): ), collections.OrderedDict, ) - expected_aggregator_type = computation_types.to_type( + expected_aggregator_type = federated_language.to_type( collections.OrderedDict(value_sum_process=(), weight_sum_process=()) ) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( (expected_optimizer_state_type, expected_aggregator_type), - placements.SERVER, + federated_language.SERVER, ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( train=collections.OrderedDict( loss=np.float32, num_examples=np.int32 ) ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) expected_initialize_type.check_equivalent_to( client_work_process.initialize.type_signature ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, weights=expected_param_model_weights_type, @@ -227,9 +226,10 @@ def test_execution_with_optimizer(self, optimizer): def test_custom_metrics_aggregator(self): def sum_then_finalize_then_times_two(metric_finalizers): - @federated_computation.federated_computation + + @federated_language.federated_computation def aggregation_computation(client_local_unfinalized_metrics): - unfinalized_metrics_sum = intrinsics.federated_sum( + unfinalized_metrics_sum = federated_language.federated_sum( client_local_unfinalized_metrics ) @@ -242,7 +242,7 @@ def finalizer_computation(unfinalized_metrics): ) return finalized_metrics - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, unfinalized_metrics_sum ) @@ -318,13 +318,13 @@ def sum_then_finalize_then_times_two( metric_finalizers, local_unfinalized_metrics_type ): - @federated_computation.federated_computation( - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ) ) def aggregation_computation(client_local_unfinalized_metrics): - unfinalized_metrics_sum = intrinsics.federated_sum( + unfinalized_metrics_sum = federated_language.federated_sum( client_local_unfinalized_metrics ) @@ -338,7 +338,7 @@ def finalizer_computation(unfinalized_metrics): ) return finalized_metrics - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, unfinalized_metrics_sum ) @@ -479,7 +479,7 @@ def test_weighted_mime_lite_with_only_secure_aggregation(self): full_gradient_aggregator=aggregator, metrics_aggregator=metrics_aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) @@ -492,7 +492,7 @@ def test_unweighted_mime_lite_with_only_secure_aggregation(self): full_gradient_aggregator=aggregator, metrics_aggregator=metrics_aggregator.secure_sum_then_finalize, ) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( learning_process.next ) diff --git a/tensorflow_federated/python/learning/algorithms/personalization_eval.py b/tensorflow_federated/python/learning/algorithms/personalization_eval.py index c8c503d91b..facf6f4275 100644 --- a/tensorflow_federated/python/learning/algorithms/personalization_eval.py +++ b/tensorflow_federated/python/learning/algorithms/personalization_eval.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Mapping from typing import Any +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import sampling @@ -24,12 +25,6 @@ from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import array_shape -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.models import variable @@ -54,7 +49,7 @@ def build_personalization_eval_computation( [variable.VariableModel, tf.data.Dataset], _MetricsType ], max_num_clients: int = 100, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Builds the TFF computation for evaluating personalization strategies. The returned TFF computation broadcasts model weights from `tff.SERVER` to @@ -135,10 +130,10 @@ def build_personalization_eval_computation( # input should contain unbatched elements. element_tff_type = _remove_batch_dim(batch_tff_type) client_input_type = collections.OrderedDict( - train_data=computation_types.SequenceType(element_tff_type), - test_data=computation_types.SequenceType(element_tff_type), + train_data=federated_language.SequenceType(element_tff_type), + test_data=federated_language.SequenceType(element_tff_type), ) - client_input_type = computation_types.to_type(client_input_type) + client_input_type = federated_language.to_type(client_input_type) py_typecheck.check_type(max_num_clients, int) if max_num_clients <= 0: @@ -159,14 +154,20 @@ def build_personalization_eval_computation( client_computation.type_signature.result ) - @federated_computation.federated_computation( - computation_types.FederatedType(model_weights_type, placements.SERVER), - computation_types.FederatedType(client_input_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType( + model_weights_type, federated_language.SERVER + ), + federated_language.FederatedType( + client_input_type, federated_language.CLIENTS + ), ) def personalization_eval(server_model_weights, federated_client_input): """TFF orchestration logic.""" - client_init_weights = intrinsics.federated_broadcast(server_model_weights) - client_final_metrics = intrinsics.federated_map( + client_init_weights = federated_language.federated_broadcast( + server_model_weights + ) + client_final_metrics = federated_language.federated_map( client_computation, (client_init_weights, federated_client_input) ) @@ -184,8 +185,8 @@ def personalization_eval(server_model_weights, federated_client_input): def _build_client_computation( - model_weights_type: computation_types.Type, - client_data_type: computation_types.Type, + model_weights_type: federated_language.Type, + client_data_type: federated_language.Type, model_fn: Callable[[], variable.VariableModel], personalize_fn_dict: Mapping[str, Callable[[], _FinetuneEvalFnType]], baseline_evaluate_fn: Callable[ @@ -230,8 +231,8 @@ def _client_computation(initial_model_weights, client_input): def _remove_batch_dim( - type_spec: computation_types.Type, -) -> computation_types.Type: + type_spec: federated_language.Type, +) -> federated_language.Type: """Removes the batch dimension from the `tff.TensorType`s in `type_spec`. Args: @@ -249,11 +250,12 @@ def _remove_batch_dim( def _remove_first_dim_in_tensortype(tensor_type): """Return a new `tff.TensorType` after removing the first dimension.""" - py_typecheck.check_type(tensor_type, computation_types.TensorType) - if tensor_type.shape is not None and not array_shape.is_shape_scalar( - tensor_type.shape + py_typecheck.check_type(tensor_type, federated_language.TensorType) + if ( + tensor_type.shape is not None + and not federated_language.array_shape_is_scalar(tensor_type.shape) ): - return computation_types.TensorType( + return federated_language.TensorType( shape=tensor_type.shape[1:], dtype=tensor_type.dtype ) else: diff --git a/tensorflow_federated/python/learning/algorithms/personalization_eval_test.py b/tensorflow_federated/python/learning/algorithms/personalization_eval_test.py index f97dc53ad9..fe5bf36352 100644 --- a/tensorflow_federated/python/learning/algorithms/personalization_eval_test.py +++ b/tensorflow_federated/python/learning/algorithms/personalization_eval_test.py @@ -16,11 +16,11 @@ from unittest import mock from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning.algorithms import personalization_eval as p13n_eval from tensorflow_federated.python.learning.models import keras_utils @@ -278,12 +278,12 @@ def model_fn(): ), ( 'tff_struct_with_python_type', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( collections.OrderedDict( - x=computation_types.TensorType( + x=federated_language.TensorType( shape=[None, 2], dtype=np.float32 ), - y=computation_types.TensorType( + y=federated_language.TensorType( shape=[None, 1], dtype=np.float32 ), ), @@ -292,12 +292,12 @@ def model_fn(): ), ( 'tff_struct_type', - computation_types.StructType( + federated_language.StructType( collections.OrderedDict( - x=computation_types.TensorType( + x=federated_language.TensorType( shape=[None, 2], dtype=np.float32 ), - y=computation_types.TensorType( + y=federated_language.TensorType( shape=[None, 1], dtype=np.float32 ), ) diff --git a/tensorflow_federated/python/learning/debug_measurements.py b/tensorflow_federated/python/learning/debug_measurements.py index 0503f2442b..0a6c970703 100644 --- a/tensorflow_federated/python/learning/debug_measurements.py +++ b/tensorflow_federated/python/learning/debug_measurements.py @@ -17,13 +17,12 @@ from collections.abc import Callable from typing import Any, TypeVar +import federated_language import tensorflow as tf from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import measurements from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import placements _AggregationFactory = TypeVar( @@ -133,23 +132,27 @@ def _calculate_unbiased_std_dev( def _calculate_client_update_statistics_with_norm(client_norms, client_weights): """Calculate client updates with client norms.""" - client_norms_squared = intrinsics.federated_map(_square_value, client_norms) + client_norms_squared = federated_language.federated_map( + _square_value, client_norms + ) - average_client_norm = intrinsics.federated_mean(client_norms, client_weights) - average_client_norm_squared = intrinsics.federated_mean( + average_client_norm = federated_language.federated_mean( + client_norms, client_weights + ) + average_client_norm_squared = federated_language.federated_mean( client_norms_squared, client_weights ) # TODO: b/197972289 - Add SecAgg compatibility to these measurements - sum_of_client_weights = intrinsics.federated_sum(client_weights) - client_weights_squared = intrinsics.federated_map( + sum_of_client_weights = federated_language.federated_sum(client_weights) + client_weights_squared = federated_language.federated_map( _square_value, client_weights ) - sum_of_client_weights_squared = intrinsics.federated_sum( + sum_of_client_weights_squared = federated_language.federated_sum( client_weights_squared ) - unbiased_std_dev = intrinsics.federated_map( + unbiased_std_dev = federated_language.federated_map( _calculate_unbiased_std_dev, ( average_client_norm, @@ -159,7 +162,7 @@ def _calculate_client_update_statistics_with_norm(client_norms, client_weights): ), ) - return intrinsics.federated_zip( + return federated_language.federated_zip( collections.OrderedDict( average_client_norm=average_client_norm, std_dev_client_norm=unbiased_std_dev, @@ -169,7 +172,7 @@ def _calculate_client_update_statistics_with_norm(client_norms, client_weights): def _calculate_client_update_statistics(client_updates, client_weights): """Calculate the average and standard deviation of client updates.""" - client_norms = intrinsics.federated_map( + client_norms = federated_language.federated_map( _calculate_global_norm, client_updates ) return _calculate_client_update_statistics_with_norm( @@ -181,7 +184,7 @@ def _calculate_client_update_statistics_mixed_dtype( client_updates, client_weights ): """Calculate client update statistics of mixed data types.""" - client_norms = intrinsics.federated_map( + client_norms = federated_language.federated_map( _calculate_global_norm_mixed_dtype, client_updates ) return _calculate_client_update_statistics_with_norm( @@ -235,11 +238,15 @@ def _build_aggregator_measurement_fns( else: def federated_client_measurement_fn(value): - client_weights = intrinsics.federated_value(1.0, placements.CLIENTS) + client_weights = federated_language.federated_value( + 1.0, federated_language.CLIENTS + ) return client_measurement_fn(value, client_weights) def federated_server_measurement_fn(value): - server_measurements = intrinsics.federated_map(server_measurement_fn, value) + server_measurements = federated_language.federated_map( + server_measurement_fn, value + ) return server_measurements return federated_client_measurement_fn, federated_server_measurement_fn diff --git a/tensorflow_federated/python/learning/debug_measurements_test.py b/tensorflow_federated/python/learning/debug_measurements_test.py index 51e09fefc5..c29bc00209 100644 --- a/tensorflow_federated/python/learning/debug_measurements_test.py +++ b/tensorflow_federated/python/learning/debug_measurements_test.py @@ -15,35 +15,37 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning import debug_measurements -FloatType = computation_types.TensorType(np.float32) -FloatAtServer = computation_types.FederatedType(FloatType, placements.SERVER) -FloatAtClients = computation_types.FederatedType(FloatType, placements.CLIENTS) +FloatType = federated_language.TensorType(np.float32) +FloatAtServer = federated_language.FederatedType( + FloatType, federated_language.SERVER +) +FloatAtClients = federated_language.FederatedType( + FloatType, federated_language.CLIENTS +) -SERVER_MEASUREMENTS_OUTPUT_TYPE = computation_types.FederatedType( +SERVER_MEASUREMENTS_OUTPUT_TYPE = federated_language.FederatedType( collections.OrderedDict([ ('server_update_max', FloatType), ('server_update_norm', FloatType), ('server_update_min', FloatType), ]), - placements.SERVER, + federated_language.SERVER, ) -CLIENT_MEASUREMENTS_OUTPUT_TYPE = computation_types.FederatedType( +CLIENT_MEASUREMENTS_OUTPUT_TYPE = federated_language.FederatedType( collections.OrderedDict([ ('average_client_norm', FloatType), ('std_dev_client_norm', FloatType), ]), - placements.SERVER, + federated_language.SERVER, ) @@ -51,12 +53,12 @@ class DebugMeasurementsTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('scalar_type', FloatType), - ('vector_type', computation_types.TensorType(np.float32, [3])), + ('vector_type', federated_language.TensorType(np.float32, [3])), ('struct_type', [FloatType, FloatType]), ( 'nested_struct_type', [ - [computation_types.TensorType(np.float32, [3])], + [federated_language.TensorType(np.float32, [3])], [FloatType, FloatType], ], ), @@ -67,9 +69,11 @@ def test_server_measurement_fn_traceable_by_federated_computation( _, server_measurement_fn = ( debug_measurements._build_aggregator_measurement_fns() ) - input_type = computation_types.FederatedType(value_type, placements.SERVER) + input_type = federated_language.FederatedType( + value_type, federated_language.SERVER + ) - @federated_computation.federated_computation(input_type) + @federated_language.federated_computation(input_type) def get_server_measurements(server_update): return server_measurement_fn(server_update) @@ -79,12 +83,12 @@ def get_server_measurements(server_update): @parameterized.named_parameters( ('scalar_type', FloatType), - ('vector_type', computation_types.TensorType(np.float32, [3])), + ('vector_type', federated_language.TensorType(np.float32, [3])), ('struct_type', [FloatType, FloatType]), ( 'nested_struct_type', [ - [computation_types.TensorType(np.float32, [3])], + [federated_language.TensorType(np.float32, [3])], [FloatType, FloatType], ], ), @@ -97,9 +101,11 @@ def test_unweighted_client_measurement_fn_traceable_by_federated_computation( weighted_aggregator=False ) ) - input_type = computation_types.FederatedType(value_type, placements.CLIENTS) + input_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS + ) - @federated_computation.federated_computation(input_type) + @federated_language.federated_computation(input_type) def get_client_measurements(client_update): return client_measurement_fn(client_update) @@ -109,12 +115,12 @@ def get_client_measurements(client_update): @parameterized.named_parameters( ('scalar_type', FloatType), - ('vector_type', computation_types.TensorType(np.float32, [3])), + ('vector_type', federated_language.TensorType(np.float32, [3])), ('struct_type', [FloatType, FloatType]), ( 'nested_struct_type', [ - [computation_types.TensorType(np.float32, [3])], + [federated_language.TensorType(np.float32, [3])], [FloatType, FloatType], ], ), @@ -127,12 +133,14 @@ def test_weighted_client_measurement_fn_traceable_by_federated_computation( weighted_aggregator=True ) ) - input_type = computation_types.FederatedType(value_type, placements.CLIENTS) - weights_type = computation_types.FederatedType( - np.float32, placements.CLIENTS + input_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS + ) + weights_type = federated_language.FederatedType( + np.float32, federated_language.CLIENTS ) - @federated_computation.federated_computation(input_type, weights_type) + @federated_language.federated_computation(input_type, weights_type) def get_client_measurements(client_update, client_weights): return client_measurement_fn(client_update, client_weights) @@ -237,9 +245,13 @@ def test_correctness_of_unweighted_client_update_statistics( ): client_weights = [1.0 for _ in client_updates] - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def compute_client_statistics(client_updates, client_weights): return debug_measurements._calculate_client_update_statistics( @@ -272,9 +284,13 @@ def test_correctness_of_weighted_client_update_statistics( self, client_updates, client_weights ): - @federated_computation.federated_computation( - computation_types.FederatedType(np.float32, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def compute_client_statistics(client_updates, client_weights): return debug_measurements._calculate_client_update_statistics( @@ -329,9 +345,13 @@ def test_correctness_of_weighted_client_update_statistics_mixed_dtype( self, client_updates, client_weights, client_type_spec ): - @federated_computation.federated_computation( - computation_types.FederatedType(client_type_spec, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType( + client_type_spec, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def compute_client_statistics(client_updates, client_weights): return debug_measurements._calculate_client_update_statistics_mixed_dtype( @@ -385,9 +405,13 @@ def test_inf_error_in_client_update_statistics( self, client_updates, client_weights, client_type_spec ): - @federated_computation.federated_computation( - computation_types.FederatedType(client_type_spec, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType( + client_type_spec, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def compute_client_statistics(client_updates, client_weights): return debug_measurements._calculate_client_update_statistics_mixed_dtype( @@ -415,9 +439,13 @@ def test_type_error_in_client_update_statistics_with_int32( ): with self.assertRaises(TypeError): - @federated_computation.federated_computation( - computation_types.FederatedType(client_type_spec, placements.CLIENTS), - computation_types.FederatedType(np.float32, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType( + client_type_spec, federated_language.CLIENTS + ), + federated_language.FederatedType( + np.float32, federated_language.CLIENTS + ), ) def compute_client_statistics(client_updates, client_weights): return ( @@ -431,7 +459,7 @@ def compute_client_statistics(client_updates, client_weights): def test_add_measurements_to_weighted_aggregation_factory_types(self): mean_factory = mean.MeanFactory() debug_mean_factory = debug_measurements.add_debug_measurements(mean_factory) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) mean_aggregator = mean_factory.create(value_type, value_type) debug_aggregator = debug_mean_factory.create(value_type, value_type) self.assertTrue(debug_aggregator.is_weighted) @@ -455,7 +483,7 @@ def test_add_measurements_to_weighted_aggregation_factory_types(self): def test_add_measurements_to_weighted_aggregation_factory_output(self): mean_factory = mean.MeanFactory() debug_mean_factory = debug_measurements.add_debug_measurements(mean_factory) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) mean_aggregator = mean_factory.create(value_type, value_type) debug_aggregator = debug_mean_factory.create(value_type, value_type) @@ -491,7 +519,7 @@ def test_add_measurements_to_weighted_aggregation_factory_output(self): def test_add_measurements_to_unweighted_aggregation_factory_types(self): mean_factory = mean.UnweightedMeanFactory() debug_mean_factory = debug_measurements.add_debug_measurements(mean_factory) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) mean_aggregator = mean_factory.create(value_type) debug_aggregator = debug_mean_factory.create(value_type) self.assertFalse(debug_aggregator.is_weighted) @@ -515,7 +543,7 @@ def test_add_measurements_to_unweighted_aggregation_factory_types(self): def test_add_measurements_to_unweighted_aggregation_factory_output(self): mean_factory = mean.UnweightedMeanFactory() debug_mean_factory = debug_measurements.add_debug_measurements(mean_factory) - value_type = computation_types.TensorType(np.float32) + value_type = federated_language.TensorType(np.float32) mean_aggregator = mean_factory.create(value_type) debug_aggregator = debug_mean_factory.create(value_type) diff --git a/tensorflow_federated/python/learning/metrics/BUILD b/tensorflow_federated/python/learning/metrics/BUILD index cf12cc7938..793a9d08af 100644 --- a/tensorflow_federated/python/learning/metrics/BUILD +++ b/tensorflow_federated/python/learning/metrics/BUILD @@ -46,13 +46,10 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -66,12 +63,10 @@ py_test( "//tensorflow_federated/python/aggregators:quantile_estimation", "//tensorflow_federated/python/core/backends/test:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:estimation_process", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -83,8 +78,7 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -95,7 +89,7 @@ py_test( ":aggregation_utils", ":keras_finalizer", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -109,11 +103,8 @@ py_library( ":types", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/templates:iterative_process", + "@federated_language//federated_language", ], ) @@ -125,10 +116,7 @@ py_test( ":keras_finalizer", ":sum_aggregation_factory", "//tensorflow_federated/python/core/backends/test:execution_contexts", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/test:static_assert", + "@federated_language//federated_language", ], ) @@ -156,7 +144,7 @@ py_test( ":keras_finalizer", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -182,18 +170,14 @@ py_test( ":sampling_aggregation_factory", ":types", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) py_library( name = "types", srcs = ["types.py"], - deps = [ - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], + deps = ["@federated_language//federated_language"], ) py_library( @@ -206,11 +190,8 @@ py_library( "//tensorflow_federated/python/aggregators:sampling", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/metrics/aggregation_utils.py b/tensorflow_federated/python/learning/metrics/aggregation_utils.py index 1c074dc298..097defa2a7 100644 --- a/tensorflow_federated/python/learning/metrics/aggregation_utils.py +++ b/tensorflow_federated/python/learning/metrics/aggregation_utils.py @@ -17,17 +17,17 @@ import typing from typing import Union +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.metrics import types def check_finalizers_matches_unfinalized_metrics( metric_finalizers: types.MetricFinalizersType, - local_unfinalized_metrics_type: computation_types.StructWithPythonType, + local_unfinalized_metrics_type: federated_language.StructWithPythonType, ): """Verifies that compatibility of variables and finalizers. @@ -92,7 +92,7 @@ def check_metric_finalizers( def check_local_unfinalized_metrics_type( - local_unfinalized_metrics_type: computation_types.StructWithPythonType, + local_unfinalized_metrics_type: federated_language.StructWithPythonType, ): """Validates `local_unfinalized_metrics_type` raising error on failure. @@ -108,7 +108,7 @@ def check_local_unfinalized_metrics_type( # the error message has a better format (specifically, the expected type is # shown as `tff.types.StructWithPythonType` in the error message). if not isinstance( - local_unfinalized_metrics_type, computation_types.StructWithPythonType + local_unfinalized_metrics_type, federated_language.StructWithPythonType ): raise TypeError( 'Expected the input `local_unfinalized_metrics_type` to be a ' @@ -130,8 +130,8 @@ def build_finalizer_computation( types.MetricFinalizersType, types.FunctionalMetricFinalizersType, ], - local_unfinalized_metrics_type: computation_types.StructWithPythonType, -) -> computation_base.Computation: + local_unfinalized_metrics_type: federated_language.StructWithPythonType, +) -> federated_language.framework.Computation: """Builds computation for finalizing metrics.""" if callable(metric_finalizers): return tensorflow_computation.tf_computation( diff --git a/tensorflow_federated/python/learning/metrics/aggregation_utils_test.py b/tensorflow_federated/python/learning/metrics/aggregation_utils_test.py index 687cc89271..33ca779972 100644 --- a/tensorflow_federated/python/learning/metrics/aggregation_utils_test.py +++ b/tensorflow_federated/python/learning/metrics/aggregation_utils_test.py @@ -16,11 +16,11 @@ from typing import Any from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.metrics import aggregation_utils from tensorflow_federated.python.learning.metrics import keras_finalizer @@ -65,7 +65,7 @@ def test_invalid_finalizers_raises(self, metric_finalizers, expected_regex): class CheckUnfinalizedMetricsTypeTest(tf.test.TestCase, parameterized.TestCase): def test_valid_type_does_not_raise(self): - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( collections.OrderedDict( num_examples=np.int32, mean=[np.float32, np.float32] ), @@ -78,7 +78,7 @@ def test_valid_type_does_not_raise(self): @parameterized.named_parameters( ( 'struct_type', - computation_types.StructType([(None, np.int32)]), + federated_language.StructType([(None, np.int32)]), '`tff.types.StructWithPythonType`', ), ( @@ -88,7 +88,7 @@ def test_valid_type_does_not_raise(self): ), ( 'list_container', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [np.float32, np.float32], list ), 'Python container', @@ -109,7 +109,7 @@ def test_match_does_not_raise(self): metric_finalizers = collections.OrderedDict( num_examples=tf.function(func=lambda x: x), mean=_tf_mean ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( collections.OrderedDict( num_examples=np.int32, mean=[np.float32, np.float32] ), @@ -126,7 +126,7 @@ def test_match_does_not_raise(self): collections.OrderedDict( num_examples=tf.function(func=lambda x: x), mean=_tf_mean ), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( collections.OrderedDict(num_examples=np.int32), collections.OrderedDict, ), @@ -135,7 +135,7 @@ def test_match_does_not_raise(self): ( 'more_metrics_in_unfinalized_metrics_type', collections.OrderedDict(mean=_tf_mean), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( collections.OrderedDict( num_examples=np.int32, mean=[np.float32, np.float32] ), @@ -162,8 +162,8 @@ class CheckBuildFinalizerComputationTest( ( 'non_functional', collections.OrderedDict(accuracy=_tf_mean), - computation_types.StructWithPythonType( - [('accuracy', computation_types.TensorType(np.float32, (2,)))], + federated_language.StructWithPythonType( + [('accuracy', federated_language.TensorType(np.float32, (2,)))], collections.OrderedDict, ), collections.OrderedDict(accuracy=[0.2, 5.0]), @@ -172,12 +172,12 @@ class CheckBuildFinalizerComputationTest( ( 'functional', _test_functional_finalize_metrics, - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [( 'accuracy', [ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), + federated_language.TensorType(np.float32), ], )], collections.OrderedDict, diff --git a/tensorflow_federated/python/learning/metrics/aggregator.py b/tensorflow_federated/python/learning/metrics/aggregator.py index 4e1887ee01..b091e2ca07 100644 --- a/tensorflow_federated/python/learning/metrics/aggregator.py +++ b/tensorflow_federated/python/learning/metrics/aggregator.py @@ -16,12 +16,10 @@ import collections from typing import Optional, Union +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.learning.metrics import aggregation_utils from tensorflow_federated.python.learning.metrics import sampling_aggregation_factory @@ -39,9 +37,9 @@ def sum_then_finalize( types.FunctionalMetricFinalizersType, ], local_unfinalized_metrics_type: Optional[ - computation_types.StructWithPythonType + federated_language.StructWithPythonType ] = None, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates a TFF computation that aggregates metrics via `sum_then_finalize`. The returned federated TFF computation is a polymorphic computation that @@ -79,7 +77,7 @@ def sum_then_finalize( del local_unfinalized_metrics_type # Unused. aggregation_utils.check_metric_finalizers(metric_finalizers) - @federated_computation.federated_computation + @federated_language.federated_computation def aggregator_computation(client_local_unfinalized_metrics): local_unfinalized_metrics_type = ( client_local_unfinalized_metrics.type_signature.member @@ -94,7 +92,7 @@ def aggregator_computation(client_local_unfinalized_metrics): aggregation_utils.check_finalizers_matches_unfinalized_metrics( metric_finalizers, local_unfinalized_metrics_type ) - unfinalized_metrics_sum = intrinsics.federated_sum( + unfinalized_metrics_sum = federated_language.federated_sum( client_local_unfinalized_metrics ) @@ -113,7 +111,7 @@ def finalizer_computation(unfinalized_metrics): ) return finalized_metrics - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, unfinalized_metrics_sum ) @@ -133,12 +131,12 @@ def secure_sum_then_finalize( types.FunctionalMetricFinalizersType, ], local_unfinalized_metrics_type: Optional[ - computation_types.StructWithPythonType + federated_language.StructWithPythonType ] = None, metric_value_ranges: Optional[ sum_aggregation_factory.UserMetricValueRangeDict ] = None, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates a TFF computation that aggregates metrics using secure summation. The returned federated TFF computation is a polymorphic computation that @@ -261,7 +259,7 @@ def _create_secure_sum_process( assert not iterative_process.is_stateful(secure_sum_process) return secure_sum_process - @federated_computation.federated_computation + @federated_language.federated_computation def aggregator_computation(client_local_unfinalized_metrics): secure_sum_process = _create_secure_sum_process( client_local_unfinalized_metrics.type_signature.member @@ -292,7 +290,7 @@ def finalizer_computation(unfinalized_metrics, secure_sum_measurements): ) return finalized_metrics - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, (unfinalized_metrics, secure_sum_measurements) ) @@ -305,10 +303,10 @@ def finalize_then_sample( types.FunctionalMetricFinalizersType, ], local_unfinalized_metrics_type: Optional[ - computation_types.StructWithPythonType + federated_language.StructWithPythonType ] = None, sample_size: int = 100, -) -> computation_base.Computation: +) -> federated_language.framework.Computation: """Creates a TFF computation to aggregate metrics via `finalize_then_sample`. The returned federated TFF computation is a polymorphic computation that @@ -387,7 +385,7 @@ def finalize_then_sample( del local_unfinalized_metrics_type # Unused. aggregation_utils.check_metric_finalizers(metric_finalizers) - @federated_computation.federated_computation + @federated_language.federated_computation def aggregator_computation(client_local_unfinalized_metrics): local_unfinalized_metrics_type = ( client_local_unfinalized_metrics.type_signature.member diff --git a/tensorflow_federated/python/learning/metrics/aggregator_test.py b/tensorflow_federated/python/learning/metrics/aggregator_test.py index 6f0a3964a3..17ca74f5d3 100644 --- a/tensorflow_federated/python/learning/metrics/aggregator_test.py +++ b/tensorflow_federated/python/learning/metrics/aggregator_test.py @@ -16,14 +16,11 @@ from typing import Any from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.test import execution_contexts -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.metrics import keras_finalizer from tensorflow_federated.python.learning.metrics import sum_aggregation_factory @@ -33,9 +30,9 @@ accuracy=tf.function(func=lambda x: x[0] / x[1]) ) _UNUSED_UNFINALIZED_METRICS = collections.OrderedDict(accuracy=[1.0, 2.0]) -_UNUSED_UNFINALIZED_METRICS_TYPE = computation_types.StructWithPythonType( +_UNUSED_UNFINALIZED_METRICS_TYPE = federated_language.StructWithPythonType( [ - ('accuracy', computation_types.TensorType(np.float32, [2])), + ('accuracy', federated_language.TensorType(np.float32, [2])), ], container_type=collections.OrderedDict, ) @@ -85,13 +82,13 @@ def result(self): custom_sum=[1, 1, np.array([1, 1], np.int32)], ), ], - 'local_unfinalized_metrics_type': computation_types.to_type( + 'local_unfinalized_metrics_type': federated_language.to_type( collections.OrderedDict( accuracy=[np.float32, np.float32], custom_sum=[ np.int32, np.int32, - computation_types.TensorType(np.int32, [2]), + federated_language.TensorType(np.int32, [2]), ], ), ), @@ -133,13 +130,13 @@ def _test_finalize_metrics( custom_sum=[1, 1, np.array([1, 1], np.int32)], ), ], - 'local_unfinalized_metrics_type': computation_types.to_type( + 'local_unfinalized_metrics_type': federated_language.to_type( collections.OrderedDict( accuracy=[np.float32, np.float32], custom_sum=[ np.int32, np.int32, - computation_types.TensorType(np.int32, [2]), + federated_language.TensorType(np.int32, [2]), ], ) ), @@ -170,7 +167,7 @@ def _test_finalize_metrics( sum=collections.OrderedDict(count_1=1, count_2=1), ), ], - 'local_unfinalized_metrics_type': computation_types.to_type( + 'local_unfinalized_metrics_type': federated_language.to_type( collections.OrderedDict( divide=[np.float32, np.float32], sum=collections.OrderedDict(count_1=np.int32, count_2=np.int32), @@ -201,7 +198,7 @@ def _test_finalize_metrics( sum=collections.OrderedDict(count_1=1, count_2=1.0), ), ], - 'local_unfinalized_metrics_type': computation_types.to_type( + 'local_unfinalized_metrics_type': federated_language.to_type( collections.OrderedDict( divide=[np.float32, np.int32], sum=collections.OrderedDict(count_1=np.int32, count_2=np.float32), @@ -234,7 +231,7 @@ def _test_finalize_metrics( { 'testcase_name': 'unfinalized_metrics_not_structure_type', 'metric_finalizers': _UNUSED_METRICS_FINALIZERS, - 'local_unfinalized_metrics_type': computation_types.TensorType( + 'local_unfinalized_metrics_type': federated_language.TensorType( np.float32 ), 'error_type': TypeError, @@ -244,7 +241,7 @@ def _test_finalize_metrics( 'testcase_name': 'unfinalized_metrics_not_ordereddict', 'metric_finalizers': _UNUSED_METRICS_FINALIZERS, 'local_unfinalized_metrics_type': ( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [np.float32, np.float32], container_type=list, ) @@ -260,9 +257,9 @@ def _test_finalize_metrics( accuracy=tf.function(func=lambda x: x[0] / x[1]) ), 'local_unfinalized_metrics_type': ( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('loss', computation_types.TensorType(np.float32, [2])), + ('loss', federated_language.TensorType(np.float32, [2])), ], container_type=collections.OrderedDict, ) @@ -292,9 +289,9 @@ def test_returns_correct_results( local_unfinalized_metrics_type=local_unfinalized_metrics_type, ) - @federated_computation.federated_computation( - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ) ) def wrapped_aggregator_computation(unfinalized_metrics): @@ -318,10 +315,10 @@ def test_fails_with_invalid_inputs( # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( local_unfinalized_metrics_type, - placements.CLIENTS, + federated_language.CLIENTS, ) ) def _aggregator_computation(unfinalized_metrics): @@ -349,15 +346,15 @@ def test_default_value_ranges_returns_correct_results( metric_finalizers=metric_finalizers ) - @federated_computation.federated_computation( - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ) ) def aggregator_computation(unfinalized_metrics): return polymorphic_aggregator_computation(unfinalized_metrics) - static_assert.assert_not_contains_unsecure_aggregation( + federated_language.framework.assert_not_contains_unsecure_aggregation( aggregator_computation ) @@ -446,20 +443,20 @@ def test_user_value_ranges_returns_correct_results(self): # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( collections.OrderedDict( accuracy=[ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), + federated_language.TensorType(np.float32), ], custom_sum=[ - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), - computation_types.TensorType(dtype=np.int32, shape=(2,)), + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), + federated_language.TensorType(dtype=np.int32, shape=(2,)), ], ), - placements.CLIENTS, + federated_language.CLIENTS, ) ) def aggregator_computation(unfinalized_metrics): @@ -535,19 +532,19 @@ def test_user_value_ranges_mixed_dtypes_returns_correct_results(self): # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( collections.OrderedDict( divide=[ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.float32), + federated_language.TensorType(np.int32), ], sum=collections.OrderedDict( - count_1=computation_types.TensorType(np.int32), - count_2=computation_types.TensorType(np.float32), + count_1=federated_language.TensorType(np.int32), + count_2=federated_language.TensorType(np.float32), ), ), - placements.CLIENTS, + federated_language.CLIENTS, ) ) def aggregator_computation(unfinalized_metrics): @@ -618,12 +615,12 @@ def result(self): # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( collections.OrderedDict( - custom_sum=computation_types.TensorType(np.str_), + custom_sum=federated_language.TensorType(np.str_), ), - placements.CLIENTS, + federated_language.CLIENTS, ) ) def _aggregator_computation(unfinalized_metrics): @@ -641,15 +638,15 @@ def test_user_value_ranges_fails_not_2_tuple(self): # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( collections.OrderedDict( accuracy=[ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), + federated_language.TensorType(np.float32), ], ), - placements.CLIENTS, + federated_language.CLIENTS, ) ) def _aggregator_computation(unfinalized_metrics): @@ -677,10 +674,10 @@ def test_fails_with_invalid_inputs( # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( local_unfinalized_metrics_type, - placements.CLIENTS, + federated_language.CLIENTS, ) ) def _aggregator_computation(unfinalized_metrics): @@ -704,15 +701,15 @@ def test_fails_with_invalid_sample_size( # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( collections.OrderedDict( accuracy=[ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), + federated_language.TensorType(np.float32), ] ), - placements.CLIENTS, + federated_language.CLIENTS, ) ) def _aggregator_computation(unfinalized_metrics): @@ -734,10 +731,10 @@ def test_fails_with_invalid_inputs( # Concretize on a federated type with CLIENTS placement so that the method # invocation understands to interpret python lists of values as CLIENTS # placed values. - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( local_unfinalized_metrics_type, - placements.CLIENTS, + federated_language.CLIENTS, ) ) def _aggregator_computation(unfinalized_metrics): @@ -758,15 +755,15 @@ def test_returns_correct_num_samples( sample_size=sample_size, ) - @federated_computation.federated_computation( - computation_types.FederatedType( + @federated_language.federated_computation( + federated_language.FederatedType( collections.OrderedDict( accuracy=[ - computation_types.TensorType(np.float32), - computation_types.TensorType(np.float32), + federated_language.TensorType(np.float32), + federated_language.TensorType(np.float32), ], ), - placements.CLIENTS, + federated_language.CLIENTS, ) ) def aggregator_computation(unfinalized_metrics): diff --git a/tensorflow_federated/python/learning/metrics/keras_finalizer_test.py b/tensorflow_federated/python/learning/metrics/keras_finalizer_test.py index ed37c990f9..339c471654 100644 --- a/tensorflow_federated/python/learning/metrics/keras_finalizer_test.py +++ b/tensorflow_federated/python/learning/metrics/keras_finalizer_test.py @@ -13,12 +13,12 @@ # limitations under the License. from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.metrics import keras_finalizer @@ -102,7 +102,7 @@ class FinalizerTest(parameterized.TestCase, tf.test.TestCase): def test_keras_metric_finalizer_returns_correct_result(self, metric): # The unfinalized accuracy contains two tensors `total` and `count`. unfinalized_accuracy = [tf.constant(2.0), tf.constant(2.0)] - unfinalized_accuracy_type = computation_types.StructWithPythonType( + unfinalized_accuracy_type = federated_language.StructWithPythonType( [np.float32, np.float32], list ) finalizer_computation = wrap_tf_function_in_tff_tf_computation( @@ -120,18 +120,18 @@ def test_keras_metric_finalizer_returns_correct_result(self, metric): 'one_variable', CustomSumMetric(has_extra_variables=False), [tf.constant(1.0)], - computation_types.StructWithPythonType([np.float32], list), + federated_language.StructWithPythonType([np.float32], list), 1.0, ), ( 'three_variables', CustomSumMetric(has_extra_variables=True), [tf.constant(1.0), tf.constant(1.0), tf.constant([1.0, 1.0])], - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ np.float32, np.float32, - computation_types.TensorType(np.float32, [2]), + federated_language.TensorType(np.float32, [2]), ], list, ), @@ -192,7 +192,7 @@ def test_create_keras_metric_finalizer_fails_with_invalid_input( ), ( 'unmatched_shape', - [computation_types.TensorType(np.float32, shape=(2,)), np.float32], + [federated_language.TensorType(np.float32, shape=(2,)), np.float32], ValueError, r'found a `tf.Tensor` of shape \(2,\) and dtype tf.float32', ), diff --git a/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory.py b/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory.py index 385bad6934..585222fb4b 100644 --- a/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory.py +++ b/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory.py @@ -15,14 +15,12 @@ from typing import Union +import federated_language + from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import sampling from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.metrics import aggregation_utils @@ -110,7 +108,7 @@ def create( types.MetricFinalizersType, types.FunctionalMetricFinalizersType, ], - local_unfinalized_metrics_type: computation_types.StructWithPythonType, + local_unfinalized_metrics_type: federated_language.StructWithPythonType, ) -> aggregation_process.AggregationProcess: """Creates a `tff.templates.AggregationProcess` for metrics aggregation. @@ -158,7 +156,7 @@ def create( sample_size=self._sample_size, return_sampling_metadata=True ).create(local_finalized_metrics_type) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): @tensorflow_computation.tf_computation def create_initial_sample_state(): @@ -166,8 +164,8 @@ def create_initial_sample_state(): local_finalized_metrics_type ) - return intrinsics.federated_eval( - create_initial_sample_state, placements.SERVER + return federated_language.federated_eval( + create_initial_sample_state, federated_language.SERVER ) # We cannot directly use `init_fn.type_signature.result` as the server state @@ -176,14 +174,14 @@ def create_initial_sample_state(): # should use `None` to denote the shape that can change over rounds. state_type = sampling.build_reservoir_type(local_finalized_metrics_type) - @federated_computation.federated_computation( - computation_types.FederatedType(state_type, placements.SERVER), - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType(state_type, federated_language.SERVER), + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ), ) def next_fn(state, client_unfinalized_metrics): - local_finalized_metrics = intrinsics.federated_map( + local_finalized_metrics = federated_language.federated_map( local_finalize_computation, client_unfinalized_metrics ) current_round_sampling_output = sampling_process.next( @@ -192,7 +190,7 @@ def next_fn(state, client_unfinalized_metrics): merge_samples_computation = sampling.build_merge_samples_computation( value_type=local_finalized_metrics_type, sample_size=self._sample_size ) - new_state = intrinsics.federated_map( + new_state = federated_language.federated_map( merge_samples_computation, (state, current_round_sampling_output.result), ) @@ -200,7 +198,7 @@ def next_fn(state, client_unfinalized_metrics): total_rounds_samples = new_state['samples'] return measured_process.MeasuredProcessOutput( state=new_state, - result=intrinsics.federated_zip( + result=federated_language.federated_zip( (current_round_samples, total_rounds_samples) ), measurements=current_round_sampling_output.measurements, diff --git a/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory_test.py b/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory_test.py index 966274d7ba..50595a2d02 100644 --- a/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory_test.py +++ b/tensorflow_federated/python/learning/metrics/sampling_aggregation_factory_test.py @@ -16,12 +16,11 @@ from typing import Any from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.metrics import sampling_aggregation_factory from tensorflow_federated.python.learning.metrics import types @@ -62,15 +61,15 @@ def _create_random_unfinalized_metrics(seed: int): tf.random.stateless_uniform(shape=(2,), seed=[seed, seed]).numpy(), ], ) - metrics_type = computation_types.StructWithPythonType( + metrics_type = federated_language.StructWithPythonType( [ - ('accuracy', computation_types.TensorType(np.float32)), - ('loss', computation_types.TensorType(np.float32, (2,))), + ('accuracy', federated_language.TensorType(np.float32)), + ('loss', federated_language.TensorType(np.float32, (2,))), ( 'custom_sum', [ - computation_types.TensorType(np.float32, shape=(1,)), - computation_types.TensorType(np.float32, shape=(2,)), + federated_language.TensorType(np.float32, shape=(1,)), + federated_language.TensorType(np.float32, shape=(2,)), ], ), ], @@ -151,10 +150,12 @@ def test_create_process_fails_with_invalid_metric_finalizers( @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_create_process_fails_with_invalid_unfinalized_metrics_type( self, bad_unfinalized_metrics_type @@ -258,7 +259,7 @@ def test_finalize_then_sample_returns_expected_samples( ) def test_finalize_then_sample_returns_correct_measurements(self, sample_size): metric_finalizers = lambda x: x - unfinalized_metrics_type = computation_types.StructWithPythonType( + unfinalized_metrics_type = federated_language.StructWithPythonType( [('loss', np.float32)], collections.OrderedDict ) sampling_process = sampling_aggregation_factory.FinalizeThenSampleFactory( diff --git a/tensorflow_federated/python/learning/metrics/sum_aggregation_factory.py b/tensorflow_federated/python/learning/metrics/sum_aggregation_factory.py index b67739dc97..ae5ed16b79 100644 --- a/tensorflow_federated/python/learning/metrics/sum_aggregation_factory.py +++ b/tensorflow_federated/python/learning/metrics/sum_aggregation_factory.py @@ -20,6 +20,7 @@ import typing from typing import Any, Optional, Union +import federated_language import numpy as np import tensorflow as tf import tree @@ -31,10 +32,6 @@ from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process @@ -47,8 +44,8 @@ def _initialize_unfinalized_metrics_accumulators( ): """Initializes the unfinalized metrics accumulators.""" if initial_unfinalized_metrics is not None: - return intrinsics.federated_value( - initial_unfinalized_metrics, placements.SERVER + return federated_language.federated_value( + initial_unfinalized_metrics, federated_language.SERVER ) @tensorflow_computation.tf_computation @@ -58,7 +55,9 @@ def create_all_zero_state(): local_unfinalized_metrics_type, ) - return intrinsics.federated_eval(create_all_zero_state, placements.SERVER) + return federated_language.federated_eval( + create_all_zero_state, federated_language.SERVER + ) # TODO: b/227811468 - Support other inner aggregators for SecAgg and DP. @@ -141,7 +140,7 @@ def __init__( def create( self, - local_unfinalized_metrics_type: computation_types.StructWithPythonType, + local_unfinalized_metrics_type: federated_language.StructWithPythonType, ) -> aggregation_process.AggregationProcess: """Creates a `tff.templates.AggregationProcess` for metrics aggregation. @@ -175,22 +174,22 @@ def create( local_unfinalized_metrics_type ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): unfinalized_metrics_accumulators = ( _initialize_unfinalized_metrics_accumulators( local_unfinalized_metrics_type, self._initial_unfinalized_metrics ) ) - return intrinsics.federated_zip(( + return federated_language.federated_zip(( inner_summation_process.initialize(), unfinalized_metrics_accumulators, )) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ), ) def next_fn( @@ -214,7 +213,7 @@ def add_unfinalized_metrics( tf.add, unfinalized_metrics, summed_unfinalized_metrics ) - unfinalized_metrics_accumulators = intrinsics.federated_map( + unfinalized_metrics_accumulators = federated_language.federated_map( add_unfinalized_metrics, (unfinalized_metrics_accumulators, summed_unfinalized_metrics), ) @@ -223,18 +222,18 @@ def add_unfinalized_metrics( self._metric_finalizers, local_unfinalized_metrics_type ) - current_round_metrics = intrinsics.federated_map( + current_round_metrics = federated_language.federated_map( finalizer_computation, summed_unfinalized_metrics ) - total_rounds_metrics = intrinsics.federated_map( + total_rounds_metrics = federated_language.federated_map( finalizer_computation, unfinalized_metrics_accumulators ) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip( + state=federated_language.federated_zip( (inner_summation_state, unfinalized_metrics_accumulators) ), - result=intrinsics.federated_zip( + result=federated_language.federated_zip( (current_round_metrics, total_rounds_metrics) ), measurements=inner_summation_output.measurements, @@ -334,7 +333,7 @@ def _check_user_metric_value_range(value_range: UserMetricValueRange): # TODO: b/233054212 - re-enable lint def create_default_secure_sum_quantization_ranges( - local_unfinalized_metrics_type: computation_types.StructWithPythonType, + local_unfinalized_metrics_type: federated_language.StructWithPythonType, lower_bound: Union[int, float] = DEFAULT_FIXED_SECURE_LOWER_BOUND, upper_bound: Union[int, float] = DEFAULT_FIXED_SECURE_UPPER_BOUND, use_auto_tuned_bounds_for_float_values: Optional[bool] = True, @@ -393,7 +392,7 @@ def create_default_secure_sum_quantization_ranges( ) def create_default_range( - type_spec: computation_types.TensorType, + type_spec: federated_language.TensorType, ) -> MetricValueRange: if np.issubdtype(type_spec.dtype, np.floating): if use_auto_tuned_bounds_for_float_values: @@ -568,7 +567,7 @@ def __init__( def create( self, - local_unfinalized_metrics_type: computation_types.StructWithPythonType, + local_unfinalized_metrics_type: federated_language.StructWithPythonType, ) -> aggregation_process.AggregationProcess: """Creates an `AggregationProcess` for secure summation over metrics. @@ -689,25 +688,27 @@ def flatten_grouped_values(value_list_by_factory_key): aggregation_process_by_factory_key[ factory_key ] = aggregator_factories.get(value_range).create( - computation_types.to_type(tensor_type_list) + federated_language.to_type(tensor_type_list) ) # pytype: disable=attribute-error,wrong-arg-types - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): factory_init_states = collections.OrderedDict() for factory_key, process in aggregation_process_by_factory_key.items(): factory_init_states[factory_key] = process.initialize() - return intrinsics.federated_zip(factory_init_states) + return federated_language.federated_zip(factory_init_states) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ), ) def next_fn(state, client_local_unfinalized_metrics): - client_local_grouped_unfinalized_metrics = intrinsics.federated_map( - group_value_by_factory_key, client_local_unfinalized_metrics + client_local_grouped_unfinalized_metrics = ( + federated_language.federated_map( + group_value_by_factory_key, client_local_unfinalized_metrics + ) ) metrics_aggregation_output = collections.OrderedDict() new_state = collections.OrderedDict() @@ -718,7 +719,7 @@ def next_fn(state, client_local_unfinalized_metrics): ) new_state[factory_key] = metrics_aggregation_output[factory_key].state - metrics_aggregation_output = intrinsics.federated_zip( + metrics_aggregation_output = federated_language.federated_zip( metrics_aggregation_output ) @@ -742,12 +743,14 @@ def flatten_aggregation_output(grouped_aggregation_output): ) return unfinalized_metrics, secure_sum_measurements - unfinalized_metrics, secure_sum_measurements = intrinsics.federated_map( - flatten_aggregation_output, metrics_aggregation_output + unfinalized_metrics, secure_sum_measurements = ( + federated_language.federated_map( + flatten_aggregation_output, metrics_aggregation_output + ) ) return measured_process.MeasuredProcessOutput( - state=intrinsics.federated_zip(new_state), + state=federated_language.federated_zip(new_state), result=unfinalized_metrics, measurements=secure_sum_measurements, ) diff --git a/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py b/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py index e27925d337..b434bcd4a0 100644 --- a/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py +++ b/tensorflow_federated/python/learning/metrics/sum_aggregation_factory_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -22,16 +23,13 @@ from tensorflow_federated.python.aggregators import quantile_estimation from tensorflow_federated.python.core.backends.test import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import estimation_process from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning.metrics import sum_aggregation_factory # Convenience aliases. -TensorType = computation_types.TensorType +TensorType = federated_language.TensorType MetricRange = sum_aggregation_factory._MetricRange @@ -52,7 +50,7 @@ def _tensor_type_from_tensor_like(x): finalizer_type = tf.nest.map_structure( _tensor_type_from_tensor_like, finalized_metrics ) - return computation_types.StructWithPythonType( + return federated_language.StructWithPythonType( finalizer_type, collections.OrderedDict ) @@ -94,7 +92,7 @@ class SumThenFinalizeFactoryComputationTest( 'scalar_metric', collections.OrderedDict(num_examples=tf.function(func=lambda x: x)), collections.OrderedDict(num_examples=1.0), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [('num_examples', np.float32)], collections.OrderedDict ), ), @@ -102,7 +100,7 @@ class SumThenFinalizeFactoryComputationTest( 'non_scalar_metric', collections.OrderedDict(loss=_tf_mean), collections.OrderedDict(loss=[2.0, 1.0]), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [('loss', [np.float32, np.float32])], collections.OrderedDict ), ), @@ -112,7 +110,7 @@ class SumThenFinalizeFactoryComputationTest( lambda x: collections.OrderedDict(mean_loss=_tf_mean(x['loss'])) ), collections.OrderedDict(loss=[1.0, 2.0]), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [('loss', [np.float32, np.float32])], collections.OrderedDict ), ), @@ -132,10 +130,10 @@ def test_type_properties( process = aggregate_factory.create(local_unfinalized_metrics_type) self.assertIsInstance(process, aggregation_process.AggregationProcess) - expected_state_type = computation_types.FederatedType( - ((), local_unfinalized_metrics_type), placements.SERVER + expected_state_type = federated_language.FederatedType( + ((), local_unfinalized_metrics_type), federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -147,15 +145,18 @@ def test_type_properties( finalized_metrics_type = _get_finalized_metrics_type( metric_finalizers, unfinalized_metrics ) - result_value_type = computation_types.FederatedType( - (finalized_metrics_type, finalized_metrics_type), placements.SERVER + result_value_type = federated_language.FederatedType( + (finalized_metrics_type, finalized_metrics_type), + federated_language.SERVER, ) - measurements_type = computation_types.FederatedType((), placements.SERVER) - expected_next_type = computation_types.FunctionType( + measurements_type = federated_language.FederatedType( + (), federated_language.SERVER + ) + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - unfinalized_metrics=computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + unfinalized_metrics=federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( @@ -200,13 +201,13 @@ def test_type_properties_with_inner_secure_sum_process( loss=[2.0, 1.0], custom_sum=[tf.constant(1.0), tf.constant([1.0, 1.0])], ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('num_examples', np.int32), ('loss', [np.float32, np.float32]), ( 'custom_sum', - [np.float32, computation_types.TensorType(np.float32, [2])], + [np.float32, federated_language.TensorType(np.float32, [2])], ), ], collections.OrderedDict, @@ -222,14 +223,14 @@ def test_type_properties_with_inner_secure_sum_process( secure_summation_process = secure_summation_factory.create( local_unfinalized_metrics_type ) - expected_state_type = computation_types.FederatedType( + expected_state_type = federated_language.FederatedType( ( secure_summation_process.initialize.type_signature.result.member, local_unfinalized_metrics_type, ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) self.assertTrue( @@ -241,18 +242,19 @@ def test_type_properties_with_inner_secure_sum_process( finalized_metrics_type = _get_finalized_metrics_type( metric_finalizers, local_unfinalized_metrics ) - result_value_type = computation_types.FederatedType( - (finalized_metrics_type, finalized_metrics_type), placements.SERVER + result_value_type = federated_language.FederatedType( + (finalized_metrics_type, finalized_metrics_type), + federated_language.SERVER, ) - measurements_type = computation_types.FederatedType( + measurements_type = federated_language.FederatedType( secure_summation_process.next.type_signature.result.measurements.member, - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, - unfinalized_metrics=computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + unfinalized_metrics=federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ), ), result=measured_process.MeasuredProcessOutput( @@ -262,7 +264,9 @@ def test_type_properties_with_inner_secure_sum_process( self.assertTrue( process.next.type_signature.is_equivalent_to(expected_next_type) ) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) @parameterized.named_parameters( ('float', 1.0), @@ -278,10 +282,12 @@ def test_incorrect_finalizers_type_raises(self, bad_finalizers): @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_unfinalized_metrics_type_raises( self, bad_unfinalized_metrics_type @@ -304,10 +310,10 @@ def test_finalizers_and_unfinalized_metrics_type_mismatch_raises(self): aggregate_factory = sum_aggregation_factory.SumThenFinalizeFactory( metric_finalizers ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( collections.OrderedDict( - x=computation_types.TensorType(shape=[None, 2], dtype=np.float32), - y=computation_types.TensorType(shape=[None, 1], dtype=np.float32), + x=federated_language.TensorType(shape=[None, 2], dtype=np.float32), + y=federated_language.TensorType(shape=[None, 1], dtype=np.float32), ), container_type=collections.OrderedDict, ) @@ -320,7 +326,7 @@ def test_unfinalized_metrics_type_and_initial_values_mismatch_raises(self): metric_finalizers = collections.OrderedDict( num_examples=tf.function(func=lambda x: x) ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [('num_examples', np.float32)], collections.OrderedDict ) initial_unfinalized_metrics = collections.OrderedDict(num_examples=[1.0]) @@ -353,13 +359,13 @@ def test_sum_then_finalize_metrics(self): np.array([1.0, 1.0], np.float32), ], ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('num_examples', np.float32), ('loss', [np.float32, np.float32]), ( 'custom_sum', - [np.float32, computation_types.TensorType(np.float32, [2])], + [np.float32, federated_language.TensorType(np.float32, [2])], ), ], collections.OrderedDict, @@ -417,7 +423,7 @@ def test_sum_then_finalize_metrics_with_initial_values(self): local_unfinalized_metrics = collections.OrderedDict( num_examples=1.0, loss=[2.0, 1.0] ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('num_examples', np.float32), ('loss', [np.float32, np.float32]), @@ -472,13 +478,13 @@ def test_secure_sum_then_finalize_metrics(self): np.array([1.0, 1.0], np.float32), ], ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('num_examples', np.int32), ('loss', [np.float32, np.float32]), ( 'custom_sum', - [np.float32, computation_types.TensorType(np.float32, [2])], + [np.float32, federated_language.TensorType(np.float32, [2])], ), ], collections.OrderedDict, @@ -512,7 +518,9 @@ def test_secure_sum_then_finalize_metrics(self): client_data = [local_unfinalized_metrics, local_unfinalized_metrics] output = process.next(state, client_data) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) _, unfinalized_metrics_accumulators = output.state # Inital clippling bounds for float values are [-100.0, 100.0], metric @@ -580,19 +588,21 @@ def test_default_value_ranges_returns_correct_results(self): np.array([1.0, 1.0], np.float32), ], ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('num_examples', np.int32), ('loss', [np.float32, np.float32]), ( 'custom_sum', - [np.float32, computation_types.TensorType(np.float32, [2])], + [np.float32, federated_language.TensorType(np.float32, [2])], ), ], collections.OrderedDict, ) process = aggregate_factory.create(local_unfinalized_metrics_type) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) state = process.initialize() @@ -656,13 +666,13 @@ def test_user_value_ranges_returns_correct_results(self): np.array([1.0, 1.0], np.float32), ], ) - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('num_examples', np.int32), ('loss', [np.float32, np.float32]), ( 'custom_sum', - [np.float32, computation_types.TensorType(np.float32, [2])], + [np.float32, federated_language.TensorType(np.float32, [2])], ), ], collections.OrderedDict, @@ -678,7 +688,9 @@ def test_user_value_ranges_returns_correct_results(self): metric_value_ranges ) process = aggregate_factory.create(local_unfinalized_metrics_type) - static_assert.assert_not_contains_unsecure_aggregation(process.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + process.next + ) state = process.initialize() custom_float_factory_key = sum_aggregation_factory.create_factory_key( @@ -755,10 +767,12 @@ def test_user_value_ranges_returns_correct_results(self): @parameterized.named_parameters( ( 'federated_type', - computation_types.FederatedType(np.float32, placements.SERVER), + federated_language.FederatedType( + np.float32, federated_language.SERVER + ), ), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_unfinalized_metrics_type_raises( self, bad_unfinalized_metrics_type @@ -770,7 +784,7 @@ def test_incorrect_unfinalized_metrics_type_raises( secure_sum_factory.create(bad_unfinalized_metrics_type) def test_user_value_ranges_fails_invalid_dtype(self): - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [ ('custom_sum', [np.str_]), ], @@ -781,7 +795,7 @@ def test_user_value_ranges_fails_invalid_dtype(self): secure_sum_factory.create(local_unfinalized_metrics_type) def test_user_value_ranges_fails_not_2_tuple(self): - local_unfinalized_metrics_type = computation_types.StructWithPythonType( + local_unfinalized_metrics_type = federated_language.StructWithPythonType( [('accuracy', [np.float32, np.float32])], collections.OrderedDict, ) @@ -816,12 +830,12 @@ def assertAutoTunedBoundEqual(self, a, b, msg=None): ('int64', TensorType(np.int64, [3]), _DEFAULT_INT_RANGE), ( '', - computation_types.to_type([np.int64, np.float32]), + federated_language.to_type([np.int64, np.float32]), [_DEFAULT_INT_RANGE, _DEFAULT_AUTO_TUNED_FLOAT_RANGE], ), ( '>', - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=np.int64, b=collections.OrderedDict( @@ -859,12 +873,12 @@ def test_default_auto_tuned_range_construction( ('int64', TensorType(np.int64, [3]), _DEFAULT_INT_RANGE), ( '', - computation_types.to_type([np.int64, np.float32]), + federated_language.to_type([np.int64, np.float32]), [_DEFAULT_INT_RANGE, _DEFAULT_FIXED_FLOAT_RANGE], ), ( '>', - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=np.int64, b=collections.OrderedDict( @@ -915,14 +929,14 @@ def test_default_fixed_range_construction(self, type_spec, expected_range): ), ( '', - computation_types.to_type([np.int64, np.float32]), + federated_language.to_type([np.int64, np.float32]), 1, 5, [(1, 5), _DEFAULT_AUTO_TUNED_FLOAT_RANGE], ), ( '>', - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=np.int64, b=collections.OrderedDict( @@ -974,14 +988,14 @@ def test_user_supplied_range_using_default_auto_tuned_range( ), ( '', - computation_types.to_type([np.int64, np.float32]), + federated_language.to_type([np.int64, np.float32]), 1, 5, [(1, 5), (1.0, 5.0)], ), ( '>', - computation_types.to_type( + federated_language.to_type( collections.OrderedDict( a=np.int64, b=collections.OrderedDict( diff --git a/tensorflow_federated/python/learning/metrics/types.py b/tensorflow_federated/python/learning/metrics/types.py index e802591cba..2ecddad601 100644 --- a/tensorflow_federated/python/learning/metrics/types.py +++ b/tensorflow_federated/python/learning/metrics/types.py @@ -17,8 +17,7 @@ from collections.abc import Callable from typing import Any, Optional, Protocol -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types +import federated_language MetricFinalizersType = collections.OrderedDict[str, Callable[[Any], Any]] @@ -33,7 +32,7 @@ def __call__( self, metric_finalizers: MetricFinalizersType, local_unfinalized_metrics_type: Optional[ - computation_types.StructWithPythonType + federated_language.StructWithPythonType ] = None, - ) -> computation_base.Computation: + ) -> federated_language.framework.Computation: ... diff --git a/tensorflow_federated/python/learning/model_update_aggregator_test.py b/tensorflow_federated/python/learning/model_update_aggregator_test.py index 4e6641f519..ad570f50c4 100644 --- a/tensorflow_federated/python/learning/model_update_aggregator_test.py +++ b/tensorflow_federated/python/learning/model_update_aggregator_test.py @@ -15,23 +15,18 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.core.backends.mapreduce import form_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import iterative_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import debug_measurements from tensorflow_federated.python.learning import model_update_aggregator -_FLOAT_TYPE = computation_types.TensorType(np.float32) -_FLOAT_MATRIX_TYPE = computation_types.TensorType(np.float32, [200, 300]) +_FLOAT_TYPE = federated_language.TensorType(np.float32) +_FLOAT_MATRIX_TYPE = federated_language.TensorType(np.float32, [200, 300]) class ModelUpdateAggregatorTest(parameterized.TestCase): @@ -197,19 +192,25 @@ def test_weighted_secure_aggregator_only_contains_secure_aggregation(self): aggregator = model_update_aggregator.secure_aggregator( weighted=True ).create(_FLOAT_MATRIX_TYPE, _FLOAT_TYPE) - static_assert.assert_not_contains_unsecure_aggregation(aggregator.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + aggregator.next + ) def test_unweighted_secure_aggregator_only_contains_secure_aggregation(self): aggregator = model_update_aggregator.secure_aggregator( weighted=False ).create(_FLOAT_MATRIX_TYPE) - static_assert.assert_not_contains_unsecure_aggregation(aggregator.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + aggregator.next + ) def test_ddp_secure_aggregator_only_contains_secure_aggregation(self): aggregator = model_update_aggregator.ddp_secure_aggregator( noise_multiplier=1e-2, expected_clients_per_round=10 ).create(_FLOAT_MATRIX_TYPE) - static_assert.assert_not_contains_unsecure_aggregation(aggregator.next) + federated_language.framework.assert_not_contains_unsecure_aggregation( + aggregator.next + ) @parameterized.named_parameters( ('zeroing_float', True, _FLOAT_TYPE), @@ -355,7 +356,7 @@ def _check_aggregated_scalar_count( ): aggregator = _mrfify_aggregator(aggregator) mrf = form_utils.get_map_reduce_form_for_computation(aggregator.next) - num_aggregated_scalars = type_analysis.count_tensors_in_type( + num_aggregated_scalars = federated_language.framework.count_tensors_in_type( mrf.work.type_signature.result )['parameters'] self.assertLess(num_aggregated_scalars, max_scalars) @@ -407,30 +408,30 @@ def _mrfify_aggregator(aggregator): if aggregator.is_weighted: - @federated_computation.federated_computation( + @federated_language.federated_computation( aggregator.next.type_signature.parameter[0], - computation_types.FederatedType( + federated_language.FederatedType( ( aggregator.next.type_signature.parameter[1].member, aggregator.next.type_signature.parameter[2].member, ), - placements.CLIENTS, + federated_language.CLIENTS, ), ) def next_fn(state, value): output = aggregator.next(state, value[0], value[1]) - return output.state, intrinsics.federated_zip( + return output.state, federated_language.federated_zip( (output.result, output.measurements) ) else: - @federated_computation.federated_computation( + @federated_language.federated_computation( aggregator.next.type_signature.parameter ) def next_fn(state, value): output = aggregator.next(state, value) - return output.state, intrinsics.federated_zip( + return output.state, federated_language.federated_zip( (output.result, output.measurements) ) diff --git a/tensorflow_federated/python/learning/models/BUILD b/tensorflow_federated/python/learning/models/BUILD index 62bee5b1cb..0028396d79 100644 --- a/tensorflow_federated/python/learning/models/BUILD +++ b/tensorflow_federated/python/learning/models/BUILD @@ -26,7 +26,7 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -65,11 +65,9 @@ py_test( ":variable", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:variable_utils", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/metrics:types", + "@federated_language//federated_language", ], ) @@ -81,11 +79,9 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/learning/metrics:counters", "//tensorflow_federated/python/learning/metrics:keras_finalizer", + "@federated_language//federated_language", ], ) @@ -101,12 +97,9 @@ py_test( "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/metrics:counters", + "@federated_language//federated_language", ], ) @@ -133,16 +126,13 @@ py_library( deps = [ ":functional", ":variable", - "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/computation:computation_impl", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -153,7 +143,7 @@ py_test( ":model_weights", ":variable", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -164,7 +154,7 @@ py_library( ":model_weights", ":variable", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -174,7 +164,7 @@ py_test( deps = [ ":model_weights", ":reconstruction_model", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -190,8 +180,7 @@ py_test( ":variable", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/models/functional_test.py b/tensorflow_federated/python/learning/models/functional_test.py index f56ae7ab66..cd98da4c17 100644 --- a/tensorflow_federated/python/learning/models/functional_test.py +++ b/tensorflow_federated/python/learning/models/functional_test.py @@ -16,14 +16,12 @@ import itertools from typing import Any, Optional +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import variable_utils -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.metrics import types from tensorflow_federated.python.learning.models import functional @@ -615,7 +613,7 @@ def test_tff_model_from_functional_federated_aggregate_metrics_succeeds(self): loss=[2.0, 4.0], mse=[2.0, 2.0], mae=[1.0, 6.0] ) metrics_aggregator = aggregator.sum_then_finalize - unfinalized_metrics_type = computation_types.to_type( + unfinalized_metrics_type = federated_language.to_type( collections.OrderedDict( loss=[np.float32, np.float32], mse=[np.float32, np.float32], @@ -631,9 +629,9 @@ def test_tff_model_from_functional_federated_aggregate_metrics_succeeds(self): # computation is later invoked on a list of values, TFF will teach each # element of the list as a single client value. This cannot be inferred from # the value itself. - @federated_computation.federated_computation( - computation_types.FederatedType( - unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + unfinalized_metrics_type, federated_language.CLIENTS ) ) def aggregate_metrics(metrics): diff --git a/tensorflow_federated/python/learning/models/keras_utils.py b/tensorflow_federated/python/learning/models/keras_utils.py index 6bf513e7d7..2b53e3c0e2 100644 --- a/tensorflow_federated/python/learning/models/keras_utils.py +++ b/tensorflow_federated/python/learning/models/keras_utils.py @@ -19,14 +19,12 @@ import warnings from absl import logging +import federated_language import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.learning.metrics import counters from tensorflow_federated.python.learning.metrics import keras_finalizer from tensorflow_federated.python.learning.models import variable @@ -161,8 +159,8 @@ def from_keras_model( 'information for both inputs to and predictions from the ' 'model. You passed input spec {}.'.format(input_spec) ) - if isinstance(input_spec, computation_types.Type): - if not type_analysis.is_structure_of_tensors(input_spec): + if isinstance(input_spec, federated_language.Type): + if not federated_language.framework.is_structure_of_tensors(input_spec): raise TypeError( 'Expected a `tff.Type` with all the leaf nodes being ' '`tff.TensorType`s, found an input spec {}.'.format(input_spec) @@ -301,7 +299,7 @@ def finalize_metric( for metric, (name, values) in zip(metrics, accumulators.items()) ]) - return intrinsics.federated_aggregate( + return federated_language.federated_aggregate( federated_values, zeros, accumulate, merge, report ) diff --git a/tensorflow_federated/python/learning/models/keras_utils_test.py b/tensorflow_federated/python/learning/models/keras_utils_test.py index 4cfad4edc3..8e22d36bea 100644 --- a/tensorflow_federated/python/learning/models/keras_utils_test.py +++ b/tensorflow_federated/python/learning/models/keras_utils_test.py @@ -20,16 +20,13 @@ import warnings from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.metrics import counters from tensorflow_federated.python.learning.models import keras_utils @@ -132,12 +129,12 @@ def test_convert_fails_on_non_keras_model(self): ), ( 'tff_struct_with_python_type', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( collections.OrderedDict( - x=computation_types.TensorType( + x=federated_language.TensorType( shape=[None, 1], dtype=np.float32 ), - y=computation_types.TensorType( + y=federated_language.TensorType( shape=[None, 1], dtype=np.float32 ), ), @@ -163,10 +160,10 @@ def test_input_spec_struct(self): keras_model = model_examples.build_linear_regression_keras_functional_model( feature_dims=1 ) - input_type = computation_types.StructWithPythonType( + input_type = federated_language.StructWithPythonType( [ - ('x', computation_types.TensorType(np.float32, [None, 1])), - ('y', computation_types.TensorType(np.float32, [None, 1])), + ('x', federated_language.TensorType(np.float32, [None, 1])), + ('y', federated_language.TensorType(np.float32, [None, 1])), ], collections.OrderedDict, ) @@ -248,15 +245,15 @@ def test_input_spec_batch_types_value_errors(self, input_spec): ), ( 'tff_type_not_tensortype', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.SequenceType( - computation_types.TensorType(np.float32) + federated_language.SequenceType( + federated_language.TensorType(np.float32) ), ), - ('y', computation_types.TensorType(np.float32, [None, 1])), + ('y', federated_language.TensorType(np.float32, [None, 1])), ], collections.OrderedDict, ), @@ -786,16 +783,17 @@ def _train_loop(): tff_model.metric_finalizers(), unfinalized_metrics_type ) - @federated_computation.federated_computation( - computation_types.FederatedType( - unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + unfinalized_metrics_type, federated_language.CLIENTS ) ) def wrapped_metrics_aggregation_computation(unfinalized_metrics): return metrics_aggregation_computation(unfinalized_metrics) self.assertIsInstance( - wrapped_metrics_aggregation_computation, computation_base.Computation + wrapped_metrics_aggregation_computation, + federated_language.framework.Computation, ) aggregated_outputs = wrapped_metrics_aggregation_computation( @@ -875,7 +873,7 @@ def finalizer_computation(unfinalized_metrics): ), ( 'tff_struct_with_python_type', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', @@ -895,7 +893,7 @@ def finalizer_computation(unfinalized_metrics): ), ( 'tff_struct_type', - computation_types.StructType([ + federated_language.StructType([ ( 'x', tensorflow_types.to_type( @@ -1045,7 +1043,7 @@ def call(self, y_true, y_pred): ), ( 'tff_struct_with_python_type', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', @@ -1065,7 +1063,7 @@ def call(self, y_true, y_pred): ), ( 'tff_struct_type', - computation_types.StructType([ + federated_language.StructType([ ( 'x', tensorflow_types.to_type( @@ -1334,9 +1332,9 @@ def update_state(self, y_true, y_pred, sample_weight=None): with self.assertRaisesRegex(TypeError, 'extra arguments'): - @federated_computation.federated_computation( - computation_types.FederatedType( - unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + unfinalized_metrics_type, federated_language.CLIENTS ) ) def _(unfinalized_metrics): @@ -1379,16 +1377,17 @@ def get_config(self): tff_model.metric_finalizers(), unfinalized_metrics_type ) - @federated_computation.federated_computation( - computation_types.FederatedType( - unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + unfinalized_metrics_type, federated_language.CLIENTS ) ) def wrapped_federated_metrics_aggregation(unfinalized_metrics): return federated_metrics_aggregation(unfinalized_metrics) self.assertIsInstance( - wrapped_federated_metrics_aggregation, computation_base.Computation + wrapped_federated_metrics_aggregation, + federated_language.framework.Computation, ) @parameterized.named_parameters( diff --git a/tensorflow_federated/python/learning/models/model_weights.py b/tensorflow_federated/python/learning/models/model_weights.py index 05afb1f206..a615476947 100644 --- a/tensorflow_federated/python/learning/models/model_weights.py +++ b/tensorflow_federated/python/learning/models/model_weights.py @@ -16,13 +16,13 @@ from collections.abc import Callable from typing import Any, NamedTuple, Union +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.models import variable @@ -105,8 +105,8 @@ def convert_variables_to_arrays(self) -> 'ModelWeights': def weights_type_from_model( - model: Union[variable.VariableModel, Callable[[], variable.VariableModel]] -) -> computation_types.StructType: + model: Union[variable.VariableModel, Callable[[], variable.VariableModel]], +) -> federated_language.StructType: """Creates a `tff.Type` from a `tff.learning.models.VariableModel` or callable that constructs a model. Args: @@ -125,12 +125,12 @@ def weights_type_from_model( py_typecheck.check_type(model, variable.VariableModel) model_weights = ModelWeights.from_model(model) - def _variable_to_type(x: tf.Variable) -> computation_types.Type: + def _variable_to_type(x: tf.Variable) -> federated_language.Type: return tensorflow_types.to_type((x.dtype, x.shape)) model_weights_type = tf.nest.map_structure(_variable_to_type, model_weights) # StructWithPythonType operates recursively, and will preserve the python type # information of model.trainable_variables and model.non_trainable_variables. - return computation_types.StructWithPythonType( + return federated_language.StructWithPythonType( model_weights_type, ModelWeights ) diff --git a/tensorflow_federated/python/learning/models/model_weights_test.py b/tensorflow_federated/python/learning/models/model_weights_test.py index 2fb765179f..f83b082f07 100644 --- a/tensorflow_federated/python/learning/models/model_weights_test.py +++ b/tensorflow_federated/python/learning/models/model_weights_test.py @@ -15,11 +15,11 @@ import collections from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import variable @@ -47,9 +47,9 @@ def local_variables(self): @property def input_spec(self): - return computation_types.StructType(( - computation_types.TensorSpec(tf.float32, [3]), - computation_types.TensorSpec(tf.float32, [1]), + return federated_language.StructType(( + federated_language.TensorSpec(tf.float32, [3]), + federated_language.TensorSpec(tf.float32, [1]), )) def predict_on_batch(self, batch_input, training=True): @@ -88,23 +88,23 @@ def test_returns_model_weights_for_model(self): model = TestModel() weights_type = model_weights.weights_type_from_model(model) self.assertEqual( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'trainable', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.float32, [3]), - computation_types.TensorType(np.float32, [1]), + federated_language.TensorType(np.float32, [3]), + federated_language.TensorType(np.float32, [1]), ], list, ), ), ( 'non_trainable', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ], list, ), @@ -118,23 +118,23 @@ def test_returns_model_weights_for_model(self): def test_returns_model_weights_for_model_callable(self): weights_type = model_weights.weights_type_from_model(TestModel) self.assertEqual( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'trainable', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.float32, [3]), - computation_types.TensorType(np.float32, [1]), + federated_language.TensorType(np.float32, [3]), + federated_language.TensorType(np.float32, [1]), ], list, ), ), ( 'non_trainable', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ], list, ), diff --git a/tensorflow_federated/python/learning/models/reconstruction_model.py b/tensorflow_federated/python/learning/models/reconstruction_model.py index 0ea0c5821e..030b6dd810 100644 --- a/tensorflow_federated/python/learning/models/reconstruction_model.py +++ b/tensorflow_federated/python/learning/models/reconstruction_model.py @@ -18,10 +18,10 @@ from collections.abc import Callable, Iterable, Mapping from typing import Any, NamedTuple, Optional +import federated_language import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import variable @@ -514,7 +514,7 @@ def __init__( global_non_trainable_variables: Iterable[tf.Variable], local_trainable_variables: Iterable[tf.Variable], local_non_trainable_variables: Iterable[tf.Variable], - input_spec: computation_types.Type, + input_spec: federated_language.Type, ): if not isinstance(inner_model, tf.keras.Model): raise TypeError( @@ -650,7 +650,7 @@ def forward_pass( def global_weights_type_from_model( model: ReconstructionModel, -) -> computation_types.StructType: +) -> federated_language.StructType: """Creates a `tff.Type` from a `tff.learning.models.ReconstructionModel`. Args: @@ -662,7 +662,7 @@ def global_weights_type_from_model( """ global_model_weights = ReconstructionModel.get_global_variables(model) - def _variable_to_type(x: tf.Variable) -> computation_types.Type: + def _variable_to_type(x: tf.Variable) -> federated_language.Type: return tensorflow_types.to_type((x.dtype, x.shape)) model_weights_type = tf.nest.map_structure( @@ -670,6 +670,6 @@ def _variable_to_type(x: tf.Variable) -> computation_types.Type: ) # StructWithPythonType operates recursively, and will preserve the python type # information of model.trainable_variables and model.non_trainable_variables. - return computation_types.StructWithPythonType( + return federated_language.StructWithPythonType( model_weights_type, model_weights.ModelWeights ) diff --git a/tensorflow_federated/python/learning/models/reconstruction_model_test.py b/tensorflow_federated/python/learning/models/reconstruction_model_test.py index c0a557cbba..08fd20975e 100644 --- a/tensorflow_federated/python/learning/models/reconstruction_model_test.py +++ b/tensorflow_federated/python/learning/models/reconstruction_model_test.py @@ -17,10 +17,10 @@ from typing import Optional from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import reconstruction_model @@ -181,18 +181,18 @@ def test_global_weights_type_with_global_and_local_layers(self): global_model_weights_type = ( reconstruction_model.global_weights_type_from_model(recon_model) ) - expected_trainable = computation_types.StructWithPythonType( + expected_trainable = federated_language.StructWithPythonType( [ - computation_types.TensorType(dtype=np.float32, shape=(5, 2)), - computation_types.TensorType(dtype=np.float32, shape=(2,)), + federated_language.TensorType(dtype=np.float32, shape=(5, 2)), + federated_language.TensorType(dtype=np.float32, shape=(2,)), ], list, ) - expected_non_trainable = computation_types.StructWithPythonType( + expected_non_trainable = federated_language.StructWithPythonType( [], list, ) - expected_type = computation_types.StructWithPythonType( + expected_type = federated_language.StructWithPythonType( [ ('trainable', expected_trainable), ('non_trainable', expected_non_trainable), @@ -262,20 +262,20 @@ def test_global_weights_type_with_only_global_layers(self): global_model_weights_type = ( reconstruction_model.global_weights_type_from_model(recon_model) ) - expected_trainable = computation_types.StructWithPythonType( + expected_trainable = federated_language.StructWithPythonType( [ - computation_types.TensorType(dtype=np.float32, shape=(5, 5)), - computation_types.TensorType(dtype=np.float32, shape=(5,)), - computation_types.TensorType(dtype=np.float32, shape=(5, 2)), - computation_types.TensorType(dtype=np.float32, shape=(2,)), + federated_language.TensorType(dtype=np.float32, shape=(5, 5)), + federated_language.TensorType(dtype=np.float32, shape=(5,)), + federated_language.TensorType(dtype=np.float32, shape=(5, 2)), + federated_language.TensorType(dtype=np.float32, shape=(2,)), ], list, ) - expected_non_trainable = computation_types.StructWithPythonType( + expected_non_trainable = federated_language.StructWithPythonType( [], list, ) - expected_type = computation_types.StructWithPythonType( + expected_type = federated_language.StructWithPythonType( [ ('trainable', expected_trainable), ('non_trainable', expected_non_trainable), diff --git a/tensorflow_federated/python/learning/models/serialization.py b/tensorflow_federated/python/learning/models/serialization.py index c85159ab5d..6217be79fe 100644 --- a/tensorflow_federated/python/learning/models/serialization.py +++ b/tensorflow_federated/python/learning/models/serialization.py @@ -16,18 +16,15 @@ import collections import functools +import federated_language +from federated_language.proto import computation_pb2 import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_impl -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization from tensorflow_federated.python.learning.models import functional from tensorflow_federated.python.learning.models import variable @@ -126,9 +123,9 @@ def deserialize_metric_finalizer(finalizer): computation_proto = computation_pb2.Computation.FromString( finalizer.read_value().numpy() ) - return computation_impl.ConcreteComputation( + return federated_language.framework.ConcreteComputation( computation_proto=computation_proto, - context_stack=context_stack_impl.context_stack, + context_stack=federated_language.framework.global_context_stack, ) return collections.OrderedDict( @@ -178,8 +175,8 @@ def _create_tensor_type(dtype, shape): concrete_fn.output_dtypes, concrete_fn.output_shapes, ) - result_type_spec = type_serialization.serialize_type( - computation_types.to_type(tensor_types) + result_type_spec = federated_language.framework.serialize_type( + federated_language.to_type(tensor_types) ) def flattened_output(*args, **kwargs): @@ -193,16 +190,16 @@ def flattened_output(*args, **kwargs): def _deserialize_type_spec(serialize_type_variable, python_container=None): """Deserialize a `tff.Type` protocol buffer into a python class instance.""" - type_spec = type_serialization.deserialize_type( + type_spec = federated_language.framework.deserialize_type( computation_pb2.Type.FromString( serialize_type_variable.read_value().numpy() ) ) if ( - isinstance(type_spec, computation_types.StructType) + isinstance(type_spec, federated_language.StructType) and python_container is not None ): - type_spec = computation_types.StructWithPythonType( + type_spec = federated_language.StructWithPythonType( structure.iter_elements(type_spec), python_container, ) @@ -339,8 +336,10 @@ def serialize_metric_finalizer(finalizer, metric_type): finalizer_computation = tensorflow_computation.tf_computation( finalizer, metric_type ) - computation_proto = computation_impl.ConcreteComputation.get_proto( - finalizer_computation + computation_proto = ( + federated_language.framework.ConcreteComputation.get_proto( + finalizer_computation + ) ) return tf.Variable( computation_proto.SerializeToString(deterministic=True), @@ -353,7 +352,7 @@ def type_for_normalized_tensor_value(value): tensor_spec = tf.TensorSpec.from_tensor(tensor) return tensorflow_types.to_type(tensor_spec) - return computation_types.to_type( + return federated_language.to_type( tf.nest.map_structure(type_for_normalized_tensor_value, values) ) @@ -368,7 +367,7 @@ def type_for_normalized_tensor_value(value): # Serialize the TFF values as string variables that contain the serialized # protos from the computation or the type. m.serialized_input_spec = tf.Variable( - type_serialization.serialize_type( + federated_language.framework.serialize_type( tensorflow_types.to_type(model.input_spec) ).SerializeToString(deterministic=True), trainable=False, @@ -463,7 +462,7 @@ def make_concrete_flat_predict_on_batch(training: bool): output_tensor_spec_structure = tf.nest.map_structure( tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs ) - result_type_spec = type_serialization.serialize_type( + result_type_spec = federated_language.framework.serialize_type( tensorflow_types.to_type(output_tensor_spec_structure) ) @@ -525,7 +524,7 @@ def make_concrete_flat_loss(): output_tensor_spec_structure = tf.nest.map_structure( tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs ) - result_type_spec = type_serialization.serialize_type( + result_type_spec = federated_language.framework.serialize_type( tensorflow_types.to_type(output_tensor_spec_structure) ) @@ -554,7 +553,7 @@ def flat_loss(output, label, sample_weight=None): # Serialize TFF values as string variables that contain the serialized # protos from the computation or the type. m.serialized_input_spec = tf.Variable( - type_serialization.serialize_type( + federated_language.framework.serialize_type( tensorflow_types.to_type(functional_model.input_spec) ).SerializeToString(deterministic=True), trainable=False, diff --git a/tensorflow_federated/python/learning/models/serialization_test.py b/tensorflow_federated/python/learning/models/serialization_test.py index 57d03e5459..3db2fa2f68 100644 --- a/tensorflow_federated/python/learning/models/serialization_test.py +++ b/tensorflow_federated/python/learning/models/serialization_test.py @@ -17,13 +17,12 @@ import os from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization from tensorflow_federated.python.learning.models import functional from tensorflow_federated.python.learning.models import keras_utils from tensorflow_federated.python.learning.models import model_examples @@ -32,9 +31,9 @@ from tensorflow_federated.python.learning.models import variable # Convenience aliases. -TensorType = computation_types.TensorType -StructType = computation_types.StructType -StructWithPythonType = computation_types.StructWithPythonType +TensorType = federated_language.TensorType +StructType = federated_language.StructType +StructWithPythonType = federated_language.StructWithPythonType class FlattenTest(tf.test.TestCase, parameterized.TestCase): @@ -142,7 +141,8 @@ def test_flatten_tf_function(self, fn, args, kwargs, expected_type_spec): # name here. self.assertEqual(type(concrete_fn).__name__, 'ConcreteFunction') self.assertProtoEquals( - type_spec, type_serialization.serialize_type(expected_type_spec) + type_spec, + federated_language.framework.serialize_type(expected_type_spec), ) @@ -227,9 +227,9 @@ def test_unflatten_tf_function( expected_python_container, ): type_spec_var = tf.Variable( - type_serialization.serialize_type(result_type_spec).SerializeToString( - deterministic=True - ) + federated_language.framework.serialize_type( + result_type_spec + ).SerializeToString(deterministic=True) ) @tf.function diff --git a/tensorflow_federated/python/learning/optimizers/BUILD b/tensorflow_federated/python/learning/optimizers/BUILD index a4698c2ce2..76fcb71280 100644 --- a/tensorflow_federated/python/learning/optimizers/BUILD +++ b/tensorflow_federated/python/learning/optimizers/BUILD @@ -125,10 +125,7 @@ py_test( "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/optimizers/integration_test.py b/tensorflow_federated/python/learning/optimizers/integration_test.py index a1210d5f18..1d2cfca89c 100644 --- a/tensorflow_federated/python/learning/optimizers/integration_test.py +++ b/tensorflow_federated/python/learning/optimizers/integration_test.py @@ -13,16 +13,13 @@ # limitations under the License. from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.optimizers import adagrad from tensorflow_federated.python.learning.optimizers import adam from tensorflow_federated.python.learning.optimizers import rmsprop @@ -93,26 +90,26 @@ def _run_in_federated_computation(optimizer, spec): lambda s: np.ones(s.shape, s.dtype.as_numpy_dtype()), spec ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_eval( + return federated_language.federated_eval( tensorflow_computation.tf_computation( lambda: optimizer.initialize(spec) ), - placements.SERVER, + federated_language.SERVER, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType( - tensorflow_types.to_type(spec), placements.SERVER + federated_language.FederatedType( + tensorflow_types.to_type(spec), federated_language.SERVER ), - computation_types.FederatedType( - tensorflow_types.to_type(spec), placements.SERVER + federated_language.FederatedType( + tensorflow_types.to_type(spec), federated_language.SERVER ), ) def next_fn(state, weights, gradients): - return intrinsics.federated_map( + return federated_language.federated_map( tensorflow_computation.tf_computation(optimizer.next), (state, weights, gradients), ) diff --git a/tensorflow_federated/python/learning/programs/BUILD b/tensorflow_federated/python/learning/programs/BUILD index c74de42e98..0369b43e75 100644 --- a/tensorflow_federated/python/learning/programs/BUILD +++ b/tensorflow_federated/python/learning/programs/BUILD @@ -29,12 +29,8 @@ py_library( deps = [ "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", - "//tensorflow_federated/python/program:federated_context", "//tensorflow_federated/python/program:file_program_state_manager", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", - "//tensorflow_federated/python/program:value_reference", + "@federated_language//federated_language", ], ) @@ -45,20 +41,13 @@ py_test( deps = [ ":evaluation_program_logic", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", - "//tensorflow_federated/python/program:federated_context", "//tensorflow_federated/python/program:file_program_state_manager", "//tensorflow_federated/python/program:native_platform", - "//tensorflow_federated/python/program:release_manager", + "@federated_language//federated_language", ], ) @@ -69,9 +58,7 @@ py_library( ":evaluation_program_logic", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", + "@federated_language//federated_language", ], ) @@ -82,11 +69,7 @@ py_library( ":evaluation_program_logic", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", - "//tensorflow_federated/python/program:federated_context", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", - "//tensorflow_federated/python/program:value_reference", + "@federated_language//federated_language", ], ) @@ -99,15 +82,10 @@ py_test( ":program_logic", ":training_program_logic", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_test_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", - "//tensorflow_federated/python/program:federated_context", "//tensorflow_federated/python/program:native_platform", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", + "@federated_language//federated_language", ], ) @@ -117,13 +95,9 @@ py_library( deps = [ ":evaluation_program_logic", ":program_logic", - "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", "//tensorflow_federated/python/program:structure_utils", - "//tensorflow_federated/python/program:value_reference", + "@federated_language//federated_language", ], ) @@ -133,13 +107,9 @@ py_test( deps = [ ":evaluation_program_logic", ":vizier_program_logic", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_test_utils", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:learning_process", - "//tensorflow_federated/python/program:data_source", "//tensorflow_federated/python/program:native_platform", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/programs/evaluation_program_logic.py b/tensorflow_federated/python/learning/programs/evaluation_program_logic.py index ed369a0fea..51904cf43e 100644 --- a/tensorflow_federated/python/learning/programs/evaluation_program_logic.py +++ b/tensorflow_federated/python/learning/programs/evaluation_program_logic.py @@ -67,16 +67,12 @@ from typing import Any, Optional from absl import logging as _logging +import federated_language import numpy as np from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source as data_source_lib -from tensorflow_federated.python.program import federated_context from tensorflow_federated.python.program import file_program_state_manager -from tensorflow_federated.python.program import program_state_manager -from tensorflow_federated.python.program import release_manager -from tensorflow_federated.python.program import value_reference # The prefix path for metrics exported to TensorBoard. This will group all the # metrics under tab with the same name. @@ -123,8 +119,8 @@ def __init__( self._lock = asyncio.Lock() # Lock for concurrency safety. async def load_latest( - self, structure: program_state_manager.ProgramStateStructure - ) -> program_state_manager.ProgramStateStructure: + self, structure: federated_language.program.ProgramStateStructure + ) -> federated_language.program.ProgramStateStructure: """Returns the latest program state. Args: @@ -139,7 +135,7 @@ async def load_latest( async def save( self, - program_state: program_state_manager.ProgramStateStructure, + program_state: federated_language.program.ProgramStateStructure, ) -> None: """Saves `program_state` and automatically advances the version number. @@ -182,10 +178,10 @@ class EvaluationManager: def __init__( self, - data_source: data_source_lib.FederatedDataSource, + data_source: federated_language.program.FederatedDataSource, aggregated_metrics_manager: Optional[ - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, int + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int ] ], create_state_manager_fn: Callable[ @@ -196,8 +192,8 @@ def __init__( tuple[ learning_process.LearningProcess, Optional[ - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, int + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int ] ], ], @@ -238,7 +234,7 @@ def __init__( self._pending_tasks: set[asyncio.Task] = set() @property - def data_source(self) -> data_source_lib.FederatedDataSource: + def data_source(self) -> federated_language.program.FederatedDataSource: """A data source used to create iterators each evaluation loop.""" return self._data_source @@ -246,7 +242,9 @@ def data_source(self) -> data_source_lib.FederatedDataSource: def aggregated_metrics_manager( self, ) -> Optional[ - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ]: """A manager for releasing metrics at the end of each evaluation loop.""" return self._aggregated_metrics_manager @@ -266,8 +264,8 @@ def create_process_fn( tuple[ learning_process.LearningProcess, Optional[ - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, int + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int ] ], ], @@ -353,8 +351,8 @@ def _start_evaluation_from_saved_model_weights( train_round_num: int, eval_process: learning_process.LearningProcess, per_round_metrics_manager: Optional[ - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, int + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int ] ], state_manager: file_program_state_manager.FileProgramStateManager, @@ -425,10 +423,10 @@ async def start_evaluation( evaluation_process, metrics_manager = self._create_evaluation_process_fn( evaluation_name ) - eval_state = await value_reference.materialize_value( + eval_state = await federated_language.program.materialize_value( evaluation_process.initialize() ) - eval_state = await value_reference.materialize_value( + eval_state = await federated_language.program.materialize_value( evaluation_process.set_model_weights(eval_state, model_weights) ) await state_manager.save(eval_state, version=0) @@ -563,14 +561,18 @@ async def _run_evaluation( state_manager: file_program_state_manager.FileProgramStateManager, evaluation_process: learning_process.LearningProcess, evaluation_name: str, - evaluation_data_source: data_source_lib.FederatedDataSource, + evaluation_data_source: federated_language.program.FederatedDataSource, evaluation_per_round_clients_number: int, evaluation_end_time: datetime.datetime, per_round_metrics_manager: Optional[ - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ], aggregated_metrics_manager: Optional[ - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ], ) -> None: """Runs evaluation for one training state. @@ -607,7 +609,7 @@ async def _run_evaluation( `tff.learning.templates.LearningProcessOutput` type. ValueError: If no previous state found for evaluation. """ - federated_context.check_in_federated_context() + federated_language.program.check_in_federated_context() evaluation_data_iterator = evaluation_data_source.iterator() @@ -621,7 +623,7 @@ async def invoke_evaluation(evaluation_state, eval_round_num): evaluation_data = evaluation_data_iterator.select( evaluation_per_round_clients_number ) - evaluation_result = await value_reference.materialize_value( + evaluation_result = await federated_language.program.materialize_value( evaluation_process.next(evaluation_state, evaluation_data) ) if isinstance(evaluation_result, learning_process.LearningProcessOutput): @@ -657,7 +659,9 @@ async def invoke_evaluation(evaluation_state, eval_round_num): # Read the initial state from the manager. If this is the first evaluation, # the zeroth version should contain the initial state. evaluation_state, version = await state_manager.load_latest( - await value_reference.materialize_value(evaluation_process.initialize()) + await federated_language.program.materialize_value( + evaluation_process.initialize() + ) ) if evaluation_state is None: raise ValueError( diff --git a/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py b/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py index ea4fedfac1..90ab886abc 100644 --- a/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py @@ -19,28 +19,21 @@ from unittest import mock from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.programs import evaluation_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source -from tensorflow_federated.python.program import federated_context from tensorflow_federated.python.program import file_program_state_manager from tensorflow_federated.python.program import native_platform -from tensorflow_federated.python.program import release_manager # Convenience aliases. -TensorType = computation_types.TensorType +TensorType = federated_language.TensorType class _NumpyMatcher: @@ -118,10 +111,10 @@ async def _value(): return native_platform.NativeValueReference(task, value_type) test_value = collections.OrderedDict( - a=awaitable_value('foo', computation_types.TensorType(np.str_)), + a=awaitable_value('foo', federated_language.TensorType(np.str_)), b=collections.OrderedDict( - x=awaitable_value('bar', computation_types.TensorType(np.str_)), - z=awaitable_value(1.0, computation_types.TensorType(np.float32)), + x=awaitable_value('bar', federated_language.TensorType(np.str_)), + z=awaitable_value(1.0, federated_language.TensorType(np.float32)), ), ) @@ -133,7 +126,7 @@ async def _value(): self.fail(f'Unexpected error raised: {e}') -def _create_test_context() -> federated_context.FederatedContext: +def _create_test_context() -> federated_language.program.FederatedContext: return native_platform.NativeFederatedContext( execution_contexts.create_async_local_cpp_execution_context() ) @@ -141,14 +134,18 @@ def _create_test_context() -> federated_context.FederatedContext: def _create_mock_datasource() -> mock.Mock: mock_datasource = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) def create_mock_iterator(*args, **kwargs) -> mock.Mock: del args # Unused del kwargs # Unused return mock.create_autospec( - data_source.FederatedDataSourceIterator, instance=True, spec_set=True + federated_language.program.FederatedDataSourceIterator, + instance=True, + spec_set=True, ) mock_datasource.iterator.side_effect = create_mock_iterator @@ -202,14 +199,18 @@ class EvaluationManagerTest(tf.test.TestCase, unittest.IsolatedAsyncioTestCase): def setUp(self): super().setUp() self.maxDiff = None - context_stack_impl.context_stack.set_default_context(_create_test_context()) + federated_language.framework.global_context_stack.set_default_context( + _create_test_context() + ) async def test_resume_nothing(self): mock_data_source = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_create_state_manager = mock.Mock() mock_meta_eval_manager = mock.create_autospec( @@ -241,10 +242,12 @@ async def test_resume_nothing(self): async def test_start_evaluations(self): mock_data_source = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) # Create a state manager with no previous evaluations. mock_meta_eval_manager = mock.create_autospec( @@ -284,10 +287,14 @@ async def test_start_evaluations(self): processes = [_create_mock_eval_process(), _create_mock_eval_process()] metrics_managers = [ mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, + instance=True, + spec_set=True, ), mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, + instance=True, + spec_set=True, ), ] mock_create_process_fn = mock.Mock( @@ -369,10 +376,12 @@ async def test_start_evaluations(self): async def test_record_finished_evaluations_removes_from_state(self): mock_data_source = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) # Create a state manager with two inflight evaluations. mock_meta_eval_manager = mock.create_autospec( @@ -421,10 +430,12 @@ async def test_record_finished_evaluations_removes_from_state(self): async def test_record_two_evaluations_finished_removes_from_state(self): mock_data_source = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) # Create a state manager with two inflight evaluations. mock_meta_eval_manager = mock.create_autospec( @@ -468,10 +479,12 @@ async def test_record_two_evaluations_finished_removes_from_state(self): async def test_resume_previous_evaluations(self): mock_data_source = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_create_state_manager = mock.Mock() mock_meta_eval_manager = mock.create_autospec( @@ -519,10 +532,14 @@ async def test_resume_previous_evaluations(self): ] + mock_resumed_eval_managers metrics_managers = [ mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, + instance=True, + spec_set=True, ), mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, + instance=True, + spec_set=True, ), ] mock_create_process_fn = mock.Mock( @@ -574,10 +591,12 @@ async def test_resume_previous_evaluations(self): async def test_failed_evaluation_raises(self): mock_data_source = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_create_state_manager = mock.Mock() mock_create_process_fn = mock.Mock() @@ -618,17 +637,20 @@ def _return_time() -> datetime.datetime: return self.time_after_end self._mock_return_time_fn = _return_time - context_stack_impl.context_stack.set_default_context(_create_test_context()) + federated_language.framework.global_context_stack.set_default_context( + _create_test_context() + ) async def test_invalid_process_rasies(self): - @federated_computation.federated_computation + + @federated_language.federated_computation def empty_initialize(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType( - computation_types.SequenceType(()), placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType( + federated_language.SequenceType(()), federated_language.CLIENTS ), ) def next_fn(state, inputs): @@ -653,10 +675,10 @@ def next_fn(state, inputs): evaluation_per_round_clients_number=num_clients, evaluation_end_time=self.end_time, per_round_metrics_manager=mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ), aggregated_metrics_manager=mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ), ) @@ -669,10 +691,10 @@ async def test_no_zero_state_raises(self): state_manager.load_latest.return_value = (None, 0) num_clients = 3 mock_per_round_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_aggregated_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) train_round_num = 1 with self.assertRaisesRegex( @@ -698,10 +720,10 @@ async def test_passed_end_time_runs_one_round(self): state_manager.load_latest.return_value = (eval_process.initialize(), 0) num_clients = 3 mock_per_round_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_aggregated_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) train_round_num = 1 await evaluation_program_logic._run_evaluation( @@ -744,10 +766,10 @@ async def test_future_end_time_runs_atleast_one_evaluation_round(self): state_manager.load_latest.return_value = (eval_process.initialize(), 0) num_clients = 3 mock_per_round_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_aggregated_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) train_round_num = 10 with mock.patch.object(datetime, 'datetime') as m_datetime: @@ -806,10 +828,10 @@ async def test_resume_evaluation_uses_correct_eval_round(self): ) num_clients = 3 mock_per_round_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_aggregated_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) train_round_num = 10 with mock.patch.object(datetime, 'datetime') as m_datetime: @@ -870,10 +892,10 @@ async def test_resume_evaluation_uses_correct_end_time(self): ) num_clients = 3 mock_per_round_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_aggregated_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) train_round_num = 10 await evaluation_program_logic._run_evaluation( diff --git a/tensorflow_federated/python/learning/programs/program_logic.py b/tensorflow_federated/python/learning/programs/program_logic.py index 9b80f3278e..6ec598ccb7 100644 --- a/tensorflow_federated/python/learning/programs/program_logic.py +++ b/tensorflow_federated/python/learning/programs/program_logic.py @@ -17,12 +17,11 @@ import typing from typing import Optional, Protocol, Union +import federated_language + from tensorflow_federated.python.learning.programs import evaluation_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source -from tensorflow_federated.python.program import program_state_manager as program_state_manager_lib -from tensorflow_federated.python.program import release_manager @typing.runtime_checkable @@ -39,16 +38,16 @@ async def __call__( *, train_process: learning_process.LearningProcess, initial_train_state: composers.LearningAlgorithmState, - train_data_source: data_source.FederatedDataSource, + train_data_source: federated_language.program.FederatedDataSource, train_per_round_clients: int, train_total_rounds: int, - program_state_manager: program_state_manager_lib.ProgramStateManager, - model_output_manager: release_manager.ReleaseManager[ - release_manager.ReleasableStructure, str + program_state_manager: federated_language.program.ProgramStateManager, + model_output_manager: federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, str ], train_metrics_manager: Optional[ - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, int + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int ] ] = None, evaluation_manager: Optional[evaluation_program_logic.EvaluationManager], diff --git a/tensorflow_federated/python/learning/programs/training_program_logic.py b/tensorflow_federated/python/learning/programs/training_program_logic.py index ef367d6c14..d2a44b4d74 100644 --- a/tensorflow_federated/python/learning/programs/training_program_logic.py +++ b/tensorflow_federated/python/learning/programs/training_program_logic.py @@ -28,15 +28,11 @@ from typing import Any, NamedTuple, Optional, Union from absl import logging +import federated_language from tensorflow_federated.python.learning.programs import evaluation_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source -from tensorflow_federated.python.program import federated_context -from tensorflow_federated.python.program import program_state_manager as program_state_manager_lib -from tensorflow_federated.python.program import release_manager -from tensorflow_federated.python.program import value_reference _PROGRAM_METRICS_KEY = 'program_metrics' @@ -59,7 +55,9 @@ class ProgramState(NamedTuple): state: composers.LearningAlgorithmState round_number: int next_evaluation_timestamp_seconds: Optional[int] - data_iterator: Optional[data_source.FederatedDataSourceIterator] + data_iterator: Optional[ + federated_language.program.FederatedDataSourceIterator + ] class TaskManager: @@ -98,7 +96,7 @@ def _add_program_metrics( metrics: Mapping[str, Any], round_end_time: datetime.datetime, num_retries: int = 0, -) -> release_manager.ReleasableStructure: +) -> federated_language.program.ReleasableStructure: """Adds program performance metrics to the metrics.""" if _PROGRAM_METRICS_KEY in metrics: raise ValueError( @@ -118,18 +116,20 @@ async def train_model( *, train_process: learning_process.LearningProcess, initial_train_state: Optional[composers.LearningAlgorithmState] = None, - train_data_source: data_source.FederatedDataSource, + train_data_source: federated_language.program.FederatedDataSource, train_per_round_clients: int, train_total_rounds: int, should_retry_round: Optional[ Callable[[learning_process.LearningProcessOutput], bool] ] = None, - program_state_manager: program_state_manager_lib.ProgramStateManager, - model_output_manager: release_manager.ReleaseManager[ - release_manager.ReleasableStructure, str + program_state_manager: federated_language.program.ProgramStateManager, + model_output_manager: federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, str ], train_metrics_manager: Optional[ - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ] = None, evaluation_manager: Optional[evaluation_program_logic.EvaluationManager], evaluation_periodicity: Union[int, datetime.timedelta], @@ -191,7 +191,7 @@ async def train_model( Raises: ValueError: If the train state is None. """ - federated_context.check_in_federated_context() + federated_language.program.check_in_federated_context() # A list of pending tasks (evaluation, value releases, etc) that we must await # before shutting down the program. @@ -208,7 +208,7 @@ async def train_model( # previous run, this program state can be used to restore the execution of # this program logic and skip unnecessary steps. if initial_train_state is None: - initial_train_state = await value_reference.materialize_value( + initial_train_state = await federated_language.program.materialize_value( train_process.initialize() ) train_state = initial_train_state @@ -314,7 +314,7 @@ def should_evaluate_round( round_participants_data = train_data_iterator.select( train_per_round_clients ) - train_result = await value_reference.materialize_value( + train_result = await federated_language.program.materialize_value( train_process.next(train_state, round_participants_data) ) if should_retry_round is not None and should_retry_round(train_result): diff --git a/tensorflow_federated/python/learning/programs/training_program_logic_test.py b/tensorflow_federated/python/learning/programs/training_program_logic_test.py index 303c52b19e..03f76d6e22 100644 --- a/tensorflow_federated/python/learning/programs/training_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/training_program_logic_test.py @@ -20,25 +20,20 @@ from unittest import mock from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning.programs import evaluation_program_logic from tensorflow_federated.python.learning.programs import program_logic from tensorflow_federated.python.learning.programs import training_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source -from tensorflow_federated.python.program import federated_context from tensorflow_federated.python.program import native_platform -from tensorflow_federated.python.program import program_state_manager -from tensorflow_federated.python.program import release_manager # Convenience aliases. ProgramState = training_program_logic.ProgramState -TensorType = computation_types.TensorType +TensorType = federated_language.TensorType class TaskManagerTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): @@ -86,13 +81,15 @@ async def task(): self.assertEmpty(task_manager._pending_tasks) -def _create_test_context() -> federated_context.FederatedContext: +def _create_test_context() -> federated_language.program.FederatedContext: return native_platform.NativeFederatedContext( execution_contexts.create_async_local_cpp_execution_context() ) -class _FakeDataSourceIterator(data_source.FederatedDataSourceIterator): +class _FakeDataSourceIterator( + federated_language.program.FederatedDataSourceIterator +): """A fake iterator that tracks the number of times it has been selected.""" def __init__(self, round_num: int): @@ -111,23 +108,25 @@ def to_bytes(self) -> bytes: return self._round_num.to_bytes(4, 'big') @property - def federated_type(self) -> computation_types.FederatedType: - return computation_types.FederatedType( - computation_types.SequenceType(element=TensorType(np.int32)), - computation_types.Placement.SERVER, + def federated_type(self) -> federated_language.FederatedType: + return federated_language.FederatedType( + federated_language.SequenceType(element=TensorType(np.int32)), + federated_language.Placement.SERVER, ) def _assert_data_source_iterators_equal( - iterator1: data_source.FederatedDataSourceIterator, - iterator2: data_source.FederatedDataSourceIterator, + iterator1: federated_language.program.FederatedDataSourceIterator, + iterator2: federated_language.program.FederatedDataSourceIterator, ): return iterator1.to_bytes() == iterator2.to_bytes() def _create_mock_datasource() -> mock.Mock: mock_datasource = mock.create_autospec( - data_source.FederatedDataSource, instance=True, spec_set=True + federated_language.program.FederatedDataSource, + instance=True, + spec_set=True, ) mock_datasource.iterator.return_value = _FakeDataSourceIterator(0) return mock_datasource @@ -155,10 +154,10 @@ def _create_mock_train_process() -> mock.Mock: ), ) type(mock_process.next).type_signature = mock.PropertyMock( - return_value=computation_types.FunctionType( + return_value=federated_language.FunctionType( parameter=( empty_state, - computation_types.SequenceType(element=TensorType(np.float32)), + federated_language.SequenceType(element=TensorType(np.float32)), ), result=mock_process.next.return_value, ) @@ -201,7 +200,7 @@ def test_is_train_model_program_logic(self): training_program_logic.train_model, program_logic.TrainModelProgramLogic ) - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_integration_runs_11_training_rounds_two_eval_rounds_from_scratch( self, ): @@ -212,15 +211,17 @@ async def test_integration_runs_11_training_rounds_two_eval_rounds_from_scratch( # Create a mock state manager that returns no previous state, starting # training from scratch. mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True, spec_set=True + federated_language.program.ProgramStateManager, + instance=True, + spec_set=True, ) mock_program_state_manager.load_latest.side_effect = [(None, 0)] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) for manager in (mock_model_output_manager, mock_train_metrics_manager): manager.release.return_value = None @@ -347,7 +348,7 @@ async def return_round_num() -> int: ) mock_evaluation_manager.wait_for_evaluations_to_finish.assert_called_once() - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_integration_runs_training_rounds_evaluates_on_time(self): train_num_clients = 5 training_rounds = 5 @@ -356,15 +357,17 @@ async def test_integration_runs_training_rounds_evaluates_on_time(self): # Create a mock state manager that returns no previous state, starting # training from scratch. mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True, spec_set=True + federated_language.program.ProgramStateManager, + instance=True, + spec_set=True, ) mock_program_state_manager.load_latest.side_effect = [(None, 0)] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) for manager in (mock_model_output_manager, mock_train_metrics_manager): manager.release.return_value = None @@ -514,7 +517,7 @@ async def return_round_num() -> None: ) mock_evaluation_manager.wait_for_evaluations_to_finish.assert_called_once() - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_integration_runs_5_training_rounds_no_eval_manager(self): train_num_clients = 5 training_rounds = 5 @@ -523,15 +526,17 @@ async def test_integration_runs_5_training_rounds_no_eval_manager(self): # Create a mock state manager that returns no previous state, starting # training from scratch. mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True, spec_set=True + federated_language.program.ProgramStateManager, + instance=True, + spec_set=True, ) mock_program_state_manager.load_latest.side_effect = [(None, 0)] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) for manager in (mock_model_output_manager, mock_train_metrics_manager): manager.release.return_value = None @@ -625,7 +630,7 @@ async def test_integration_runs_5_training_rounds_no_eval_manager(self): mock_model_output_manager.release.call_args_list, ) - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_program_state_manager_work_with_initial_state(self): initial_train_state = composers.LearningAlgorithmState( global_model_weights=(), @@ -638,15 +643,17 @@ async def test_program_state_manager_work_with_initial_state(self): # Create a mock state manager that returns no previous state, starting # training from scratch. mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True, spec_set=True + federated_language.program.ProgramStateManager, + instance=True, + spec_set=True, ) mock_program_state_manager.load_latest.side_effect = [(None, 0)] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) for manager in (mock_model_output_manager, mock_train_metrics_manager): manager.release.return_value = None @@ -706,7 +713,7 @@ async def test_program_state_manager_work_with_initial_state(self): _FakeDataSourceIterator(0), ) - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_resumes_from_previous_version_10_runs_one_round(self): train_num_clients = 5 training_rounds = 11 @@ -718,7 +725,7 @@ async def test_resumes_from_previous_version_10_runs_one_round(self): # (one before the last requested round). training_state = training_process.initialize() mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True + federated_language.program.ProgramStateManager, instance=True ) mock_program_state_manager.load_latest.side_effect = [( ProgramState( @@ -731,10 +738,10 @@ async def test_resumes_from_previous_version_10_runs_one_round(self): )] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) for manager in (mock_model_output_manager, mock_train_metrics_manager): manager.release.return_value = None @@ -829,7 +836,7 @@ async def test_resumes_from_previous_version_10_runs_one_round(self): mock_evaluation_manager.resume_from_previous_state.assert_called_once() mock_evaluation_manager.wait_for_evaluations_to_finish.assert_called_once() - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_resumes_from_previous_runs_no_train_rounds(self): train_num_clients = 5 training_rounds = 10 @@ -841,7 +848,7 @@ async def test_resumes_from_previous_runs_no_train_rounds(self): # completed the entire training process. training_state = training_process.initialize() mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True + federated_language.program.ProgramStateManager, instance=True ) mock_program_state_manager.load_latest.side_effect = [( ProgramState( @@ -854,10 +861,10 @@ async def test_resumes_from_previous_runs_no_train_rounds(self): )] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True + federated_language.program.ReleaseManager, instance=True ) # Run an evaluation every round, but we assert none were run because no @@ -917,7 +924,7 @@ async def test_resumes_from_previous_runs_no_train_rounds(self): mock_evaluation_manager.resume_from_previous_state.assert_called_once() mock_evaluation_manager.wait_for_evaluations_to_finish.assert_called_once() - @context_stack_test_utils.with_context(_create_test_context) + @federated_language.framework.with_context(_create_test_context) async def test_retries_one_round( self, ): @@ -955,15 +962,17 @@ def test_should_retry_round(train_result): # Create a mock state manager that returns no previous state, starting # training from scratch. mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, instance=True, spec_set=True + federated_language.program.ProgramStateManager, + instance=True, + spec_set=True, ) mock_program_state_manager.load_latest.side_effect = [(None, 0)] mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, instance=True, spec_set=True + federated_language.program.ReleaseManager, instance=True, spec_set=True ) for manager in (mock_model_output_manager, mock_train_metrics_manager): manager.release.return_value = None diff --git a/tensorflow_federated/python/learning/programs/vizier_program_logic.py b/tensorflow_federated/python/learning/programs/vizier_program_logic.py index 8e4a3e1e57..0747bea9fd 100644 --- a/tensorflow_federated/python/learning/programs/vizier_program_logic.py +++ b/tensorflow_federated/python/learning/programs/vizier_program_logic.py @@ -18,25 +18,23 @@ import datetime from typing import Optional, Protocol, Union +import federated_language from vizier import pyvizier from vizier.client import client_abc -from tensorflow_federated.python.core.impl.computation import computation_base from tensorflow_federated.python.learning.programs import evaluation_program_logic from tensorflow_federated.python.learning.programs import program_logic from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source -from tensorflow_federated.python.program import program_state_manager as program_state_manager_lib -from tensorflow_federated.python.program import release_manager from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference class IntReleaseManagerFactory(Protocol): def __call__( self, trial: client_abc.TrialInterface - ) -> release_manager.ReleaseManager[release_manager.ReleasableStructure, int]: + ) -> federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ]: pass @@ -44,7 +42,9 @@ class StrReleaseManagerFactory(Protocol): def __call__( self, trial: client_abc.TrialInterface - ) -> release_manager.ReleaseManager[release_manager.ReleasableStructure, str]: + ) -> federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, str + ]: pass @@ -52,7 +52,7 @@ class ProgramStateManagerFactory(Protocol): def __call__( self, trial: client_abc.TrialInterface - ) -> program_state_manager_lib.ProgramStateManager: + ) -> federated_language.program.ProgramStateManager: pass @@ -74,12 +74,12 @@ def __call__( async def _create_measurement( - value: value_reference.MaterializableStructure, + value: federated_language.program.MaterializableStructure, steps: int, creation_time: datetime.datetime, ) -> pyvizier.Measurement: """Creates a Vizier Measurement for the given `value`.""" - materialized_value = await value_reference.materialize_value(value) + materialized_value = await federated_language.program.materialize_value(value) flattened_value = structure_utils.flatten_with_name(materialized_value) metrics = {k: v for k, v in flattened_value} elapsed_time = datetime.datetime.now().astimezone() - creation_time @@ -90,7 +90,9 @@ async def _create_measurement( class _IntermediateMeasurementReleaseManager( - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ): """Releases metrics as a trial's intermediate measurement.""" @@ -98,7 +100,7 @@ def __init__(self, trial: client_abc.TrialInterface): self._trial = trial async def release( - self, value: release_manager.ReleasableStructure, key: int + self, value: federated_language.program.ReleasableStructure, key: int ) -> None: creation_time = self._trial.materialize().creation_time measurement = await _create_measurement( @@ -116,7 +118,9 @@ def __eq__(self, other: object) -> bool: class _FinalMeasurementReleaseManager( - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ): """Releases metrics as a trial's final measurement.""" @@ -124,7 +128,7 @@ def __init__(self, trial: client_abc.TrialInterface): self._trial = trial async def release( - self, value: release_manager.ReleasableStructure, key: int + self, value: federated_language.program.ReleasableStructure, key: int ) -> None: creation_time = self._trial.materialize().creation_time measurement = await _create_measurement( @@ -145,10 +149,10 @@ async def train_model_with_vizier( study: client_abc.StudyInterface, total_trials: int, num_parallel_trials: int = 1, - update_hparams: computation_base.Computation, + update_hparams: federated_language.framework.Computation, train_model_program_logic: program_logic.TrainModelProgramLogic, train_process_factory: TrainProcessFactory, - train_data_source: data_source.FederatedDataSource, + train_data_source: federated_language.program.FederatedDataSource, total_rounds: int, num_clients: int, program_state_manager_factory: ProgramStateManagerFactory, @@ -216,20 +220,24 @@ async def train_model_with_vizier_one_worker(vizier_worker_name): ) if train_metrics_manager_factory is not None: manager = train_metrics_manager_factory(trial) - train_metrics_manager = release_manager.GroupingReleaseManager([ - manager, - intermediate_release_manager, - ]) + train_metrics_manager = ( + federated_language.program.GroupingReleaseManager([ + manager, + intermediate_release_manager, + ]) + ) else: train_metrics_manager = intermediate_release_manager final_release_manager = _FinalMeasurementReleaseManager(trial) evaluation_manager = evaluation_manager_factory(trial) if evaluation_manager.aggregated_metrics_manager is not None: - aggregated_metrics_manager = release_manager.GroupingReleaseManager([ - evaluation_manager.aggregated_metrics_manager, - final_release_manager, - ]) + aggregated_metrics_manager = ( + federated_language.program.GroupingReleaseManager([ + evaluation_manager.aggregated_metrics_manager, + final_release_manager, + ]) + ) else: aggregated_metrics_manager = final_release_manager evaluation_manager = evaluation_program_logic.EvaluationManager( diff --git a/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py b/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py index 62fe1b306e..e27626c87c 100644 --- a/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/vizier_program_logic_test.py @@ -18,18 +18,14 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language from vizier.client import client_abc -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_test_utils from tensorflow_federated.python.learning.programs import evaluation_program_logic from tensorflow_federated.python.learning.programs import vizier_program_logic from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import learning_process -from tensorflow_federated.python.program import data_source from tensorflow_federated.python.program import native_platform -from tensorflow_federated.python.program import program_state_manager -from tensorflow_federated.python.program import release_manager def _create_mock_context() -> mock.Mock: @@ -68,7 +64,7 @@ class TrainModelWithVizierTest( parameterized.TestCase, unittest.IsolatedAsyncioTestCase ): - @context_stack_test_utils.with_context(_create_mock_context) + @federated_language.framework.with_context(_create_mock_context) async def test_calls_program_components(self): total_trials = 10 total_rounds = 10 @@ -99,34 +95,38 @@ def suggest(*args, **kwargs): mock_study.trials().get.side_effect = lambda: suggested_trials mock_update_hparams = mock.create_autospec( - computation_base.Computation, spec_set=True + federated_language.framework.Computation, spec_set=True ) mock_train_model_program_logic = mock.AsyncMock() mock_train_process = _create_mock_train_process() mock_train_process_factory = mock.Mock(return_value=mock_train_process) mock_train_data_source = mock.create_autospec( - data_source.FederatedDataSource, spec_set=True, instance=True + federated_language.program.FederatedDataSource, + spec_set=True, + instance=True, ) mock_program_state_manager = mock.create_autospec( - program_state_manager.ProgramStateManager, spec_set=True, instance=True + federated_language.program.ProgramStateManager, + spec_set=True, + instance=True, ) mock_program_state_manager_factory = mock.Mock( return_value=mock_program_state_manager ) mock_model_output_manager = mock.create_autospec( - release_manager.ReleaseManager, spec_set=True, instance=True + federated_language.program.ReleaseManager, spec_set=True, instance=True ) mock_model_output_manager_factory = mock.Mock( return_value=mock_model_output_manager ) mock_train_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, spec_set=True, instance=True + federated_language.program.ReleaseManager, spec_set=True, instance=True ) mock_train_metrics_manager_factory = mock.Mock( return_value=mock_train_metrics_manager ) mock_aggregated_metrics_manager = mock.create_autospec( - release_manager.ReleaseManager, spec_set=True, instance=True + federated_language.program.ReleaseManager, spec_set=True, instance=True ) mock_evaluation_manager = mock.create_autospec( evaluation_program_logic.EvaluationManager, spec_set=True, instance=True @@ -225,7 +225,8 @@ def suggest(*args, **kwargs): if mock_train_metrics_manager_factory is not None: self.assertIsInstance( - actual_train_metrics_manager, release_manager.GroupingReleaseManager + actual_train_metrics_manager, + federated_language.program.GroupingReleaseManager, ) self.assertIn( expected_intermediate_measurement_release_manager, @@ -251,7 +252,7 @@ def suggest(*args, **kwargs): if mock_evaluation_manager.aggregated_metrics_manager is not None: self.assertIsInstance( actual_aggregated_metrics_manager, - release_manager.GroupingReleaseManager, + federated_language.program.GroupingReleaseManager, ) self.assertIn( expected_final_measurement_release_manager, diff --git a/tensorflow_federated/python/learning/templates/BUILD b/tensorflow_federated/python/learning/templates/BUILD index 023d61edd1..72b6317ef2 100644 --- a/tensorflow_federated/python/learning/templates/BUILD +++ b/tensorflow_federated/python/learning/templates/BUILD @@ -39,15 +39,11 @@ py_library( ":finalizers", "//tensorflow_federated/python/core/environments/tensorflow_backend:type_conversions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:tensor_utils", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", + "@federated_language//federated_language", ], ) @@ -58,13 +54,11 @@ py_test( ":apply_optimizer_finalizer", ":finalizers", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) @@ -74,12 +68,9 @@ py_library( deps = [ ":hparams_base", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -91,13 +82,10 @@ py_test( ":client_works", ":hparams_base", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/models:model_weights", + "@federated_language//federated_language", ], ) @@ -115,16 +103,12 @@ py_library( "//tensorflow_federated/python/aggregators:mean", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:variable", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) @@ -146,16 +130,13 @@ py_test( "//tensorflow_federated/python/common_libs:structure", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning/models:keras_utils", "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) @@ -165,13 +146,9 @@ py_library( deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -183,12 +160,9 @@ py_test( ":distributors", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -198,13 +172,9 @@ py_library( deps = [ ":hparams_base", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", + "@federated_language//federated_language", ], ) @@ -216,13 +186,10 @@ py_test( ":finalizers", ":hparams_base", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/models:model_weights", + "@federated_language//federated_language", ], ) @@ -232,8 +199,7 @@ py_library( deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -243,8 +209,7 @@ py_test( deps = [ ":hparams_base", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_test_utils", + "@federated_language//federated_language", ], ) @@ -254,11 +219,9 @@ py_library( deps = [ ":hparams_base", "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:iterative_process", + "@federated_language//federated_language", ], ) @@ -269,11 +232,8 @@ py_test( deps = [ ":learning_process", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", + "@federated_language//federated_language", ], ) @@ -285,10 +245,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", @@ -299,6 +255,7 @@ py_library( "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:variable", "//tensorflow_federated/python/learning/optimizers:optimizer", + "@federated_language//federated_language", ], ) @@ -313,11 +270,6 @@ py_test( "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", @@ -327,6 +279,7 @@ py_test( "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) @@ -338,10 +291,6 @@ py_library( "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", @@ -352,6 +301,7 @@ py_library( "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:variable", "//tensorflow_federated/python/learning/optimizers:optimizer", + "@federated_language//federated_language", ], ) @@ -366,10 +316,6 @@ py_test( "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", @@ -378,17 +324,14 @@ py_test( "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", + "@federated_language//federated_language", ], ) py_library( name = "type_checks", srcs = ["type_checks.py"], - deps = [ - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - ], + deps = ["@federated_language//federated_language"], ) py_test( @@ -396,7 +339,6 @@ py_test( srcs = ["type_checks_test.py"], deps = [ ":type_checks", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py index 6b6e850778..d93e8a816b 100644 --- a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py +++ b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py @@ -17,15 +17,11 @@ from collections.abc import Callable from typing import Any, Optional, Union +import federated_language import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import tensor_utils from tensorflow_federated.python.learning.models import model_weights @@ -62,7 +58,7 @@ def reject_non_finite_update( def _build_tff_optimizer_initialize_and_next( - model_weights_type: computation_types.Type, + model_weights_type: federated_language.Type, optimizer: optimizer_base.Optimizer, should_reject_update: Callable[ [Any, Any], tuple[Union[bool, tf.Tensor], Optional[_MeasurementsType]] @@ -100,7 +96,7 @@ def next_fn(optimizer_state, trainable_weights, update): def build_apply_optimizer_finalizer( optimizer_fn: optimizer_base.Optimizer, - model_weights_type: computation_types.StructType, + model_weights_type: federated_language.StructType, should_reject_update: Callable[ [Any, Any], tuple[Union[bool, tf.Tensor], Optional[_MeasurementsType]] ] = reject_non_finite_update, @@ -141,9 +137,13 @@ def build_apply_optimizer_finalizer( Python container, or contains a `tff.types.FederatedType`. """ if ( - not isinstance(model_weights_type, computation_types.StructWithPythonType) + not isinstance( + model_weights_type, federated_language.StructWithPythonType + ) or model_weights_type.python_container != model_weights.ModelWeights - or type_analysis.contains_federated_types(model_weights_type) + or federated_language.framework.contains_federated_types( + model_weights_type + ) ): raise TypeError( 'Provided value_type must be a tff.types.StructType with its python ' @@ -155,22 +155,26 @@ def build_apply_optimizer_finalizer( model_weights_type, optimizer_fn, should_reject_update ) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_eval(init_tf, placements.SERVER) + return federated_language.federated_eval(init_tf, federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(model_weights_type, placements.SERVER), - computation_types.FederatedType( - model_weights_type.trainable, placements.SERVER + federated_language.FederatedType( + model_weights_type, federated_language.SERVER + ), + federated_language.FederatedType( + model_weights_type.trainable, federated_language.SERVER ), ) def next_fn(state, weights, update): optimizer_state, new_trainable_weights, measurements = ( - intrinsics.federated_map(next_tf, (state, weights.trainable, update)) + federated_language.federated_map( + next_tf, (state, weights.trainable, update) + ) ) - new_weights = intrinsics.federated_zip( + new_weights = federated_language.federated_zip( model_weights.ModelWeights(new_trainable_weights, weights.non_trainable) ) return measured_process.MeasuredProcessOutput( diff --git a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py index 040efe9e52..5a0da01819 100644 --- a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py +++ b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py @@ -16,13 +16,11 @@ import copy from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base @@ -30,13 +28,15 @@ from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import finalizers -SERVER_FLOAT = computation_types.FederatedType(np.float32, placements.SERVER) +SERVER_FLOAT = federated_language.FederatedType( + np.float32, federated_language.SERVER +) _MODEL_WEIGHTS_SPEC = model_weights.ModelWeights( trainable=(np.float32,), non_trainable=(np.float32,) ) -_MODEL_WEIGHTS_TYPE = computation_types.to_type(_MODEL_WEIGHTS_SPEC) -_SERVER_MODEL_WEIGHTS_TYPE = computation_types.FederatedType( - _MODEL_WEIGHTS_TYPE, placements.SERVER +_MODEL_WEIGHTS_TYPE = federated_language.to_type(_MODEL_WEIGHTS_SPEC) +_SERVER_MODEL_WEIGHTS_TYPE = federated_language.FederatedType( + _MODEL_WEIGHTS_TYPE, federated_language.SERVER ) MeasuredProcessOutput = measured_process.MeasuredProcessOutput @@ -82,18 +82,18 @@ def test_initialize_has_expected_type_with_tff_optimizer(self): sgdm.build_sgdm(1.0), _MODEL_WEIGHTS_TYPE ) - expected_state_type = computation_types.FederatedType( - computation_types.to_type( + expected_state_type = federated_language.FederatedType( + federated_language.to_type( collections.OrderedDict( [(optimizer_base.LEARNING_RATE_KEY, np.float32)] ) ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( finalizer.initialize.type_signature, expected_initialize_type ) @@ -103,22 +103,23 @@ def test_next_has_expected_type_with_tff_optimizer(self): ) expected_param_weights_type = _SERVER_MODEL_WEIGHTS_TYPE - expected_param_update_type = computation_types.FederatedType( - _MODEL_WEIGHTS_TYPE.trainable, placements.SERVER + expected_param_update_type = federated_language.FederatedType( + _MODEL_WEIGHTS_TYPE.trainable, federated_language.SERVER ) expected_result_type = _SERVER_MODEL_WEIGHTS_TYPE - expected_state_type = computation_types.FederatedType( - computation_types.to_type( + expected_state_type = federated_language.FederatedType( + federated_language.to_type( collections.OrderedDict( [(optimizer_base.LEARNING_RATE_KEY, np.float32)] ) ), - placements.SERVER, + federated_language.SERVER, ) - expected_measurements_type = computation_types.FederatedType( - collections.OrderedDict(update_non_finite=np.int32), placements.SERVER + expected_measurements_type = federated_language.FederatedType( + collections.OrderedDict(update_non_finite=np.int32), + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, weights=expected_param_weights_type, @@ -130,7 +131,7 @@ def test_next_has_expected_type_with_tff_optimizer(self): expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( finalizer.next.type_signature, expected_next_type ) @@ -143,10 +144,10 @@ def test_get_hparams_has_expected_type_with_tff_optimizer(self): [(optimizer_base.LEARNING_RATE_KEY, np.float32)] ) expected_hparams_type = expected_state_type - expected_get_hparams_type = computation_types.FunctionType( + expected_get_hparams_type = federated_language.FunctionType( parameter=expected_state_type, result=expected_hparams_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( finalizer.get_hparams.type_signature, expected_get_hparams_type ) @@ -159,36 +160,36 @@ def test_set_hparams_has_expected_type_with_tff_optimizer(self): [(optimizer_base.LEARNING_RATE_KEY, np.float32)] ) expected_hparams_type = expected_state_type - expected_set_hparams_type = computation_types.FunctionType( - parameter=computation_types.StructType( + expected_set_hparams_type = federated_language.FunctionType( + parameter=federated_language.StructType( [('state', expected_state_type), ('hparams', expected_hparams_type)] ), result=expected_state_type, ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( finalizer.set_hparams.type_signature, expected_set_hparams_type ) @parameterized.named_parameters( - ('not_struct', computation_types.TensorType(np.float32)), + ('not_struct', federated_language.TensorType(np.float32)), ('federated_type', _SERVER_MODEL_WEIGHTS_TYPE), ( 'model_weights_of_federated_types', - computation_types.to_type( + federated_language.to_type( model_weights.ModelWeights(SERVER_FLOAT, SERVER_FLOAT) ), ), ( 'not_model_weights', - computation_types.to_type((np.float32, np.float32)), + federated_language.to_type((np.float32, np.float32)), ), ( 'function_type', - computation_types.FunctionType(None, _SERVER_MODEL_WEIGHTS_TYPE), + federated_language.FunctionType(None, _SERVER_MODEL_WEIGHTS_TYPE), ), ( 'sequence_type', - computation_types.SequenceType(_SERVER_MODEL_WEIGHTS_TYPE.member), + federated_language.SequenceType(_SERVER_MODEL_WEIGHTS_TYPE.member), ), ) def test_incorrect_value_type_raises(self, bad_type): diff --git a/tensorflow_federated/python/learning/templates/client_works.py b/tensorflow_federated/python/learning/templates/client_works.py index ab2249db0e..d9625f2ddf 100644 --- a/tensorflow_federated/python/learning/templates/client_works.py +++ b/tensorflow_federated/python/learning/templates/client_works.py @@ -15,11 +15,9 @@ from typing import Any, NamedTuple, Optional +import federated_language + from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.templates import hparams_base @@ -45,11 +43,13 @@ class ClientResultTypeError(TypeError): # TODO: b/240314933 - Move this (or refactor this) to a more general location. -def _is_allowed_client_data_type(type_spec: computation_types.Type) -> bool: +def _is_allowed_client_data_type(type_spec: federated_language.Type) -> bool: """Determines whether a given type is a (possibly nested) sequence type.""" - if isinstance(type_spec, computation_types.SequenceType): - return type_analysis.is_tensorflow_compatible_type(type_spec.element) - elif isinstance(type_spec, computation_types.StructType): + if isinstance(type_spec, federated_language.SequenceType): + return federated_language.framework.is_tensorflow_compatible_type( + type_spec.element + ) + elif isinstance(type_spec, federated_language.StructType): return all( _is_allowed_client_data_type(element_type) for element_type in type_spec.children() @@ -59,9 +59,11 @@ def _is_allowed_client_data_type(type_spec: computation_types.Type) -> bool: # TODO: b/240314933 - Move this (or refactor this) to a more general location. -def _type_check_initialize_fn(initialize_fn: computation_base.Computation): +def _type_check_initialize_fn( + initialize_fn: federated_language.framework.Computation, +): if not isinstance( - initialize_fn.type_signature.result, computation_types.FederatedType + initialize_fn.type_signature.result, federated_language.FederatedType ): raise errors.TemplateNotFederatedError( 'Provided `initialize_fn` must return a federated type, but found ' @@ -69,7 +71,7 @@ def _type_check_initialize_fn(initialize_fn: computation_base.Computation): 'see a collection of federated types, try wrapping the returned ' 'value in `tff.federated_zip` before returning.' ) - if initialize_fn.type_signature.result.placement != placements.SERVER: # pytype: disable=attribute-error + if initialize_fn.type_signature.result.placement != federated_language.SERVER: # pytype: disable=attribute-error raise errors.TemplatePlacementError( 'The state controlled by a `ClientWorkProcess` must be placed at ' f'the SERVER, but found type: {initialize_fn.type_signature.result}.' @@ -77,21 +79,21 @@ def _type_check_initialize_fn(initialize_fn: computation_base.Computation): # TODO: b/240314933 - Move this (or refactor this) to a more general location. -def _check_next_fn_is_federated(next_fn: computation_base.Computation): +def _check_next_fn_is_federated( + next_fn: federated_language.framework.Computation, +): """Checks that a given `next_fn` has federated inputs and outputs.""" next_types = structure.flatten( next_fn.type_signature.parameter ) + structure.flatten(next_fn.type_signature.result) if not all( - [isinstance(t, computation_types.FederatedType) for t in next_types] + [isinstance(t, federated_language.FederatedType) for t in next_types] ): - offending_types = '\n- '.join( - [ - t - for t in next_types - if not isinstance(t, computation_types.FederatedType) - ] - ) + offending_types = '\n- '.join([ + t + for t in next_types + if not isinstance(t, federated_language.FederatedType) + ]) raise errors.TemplateNotFederatedError( 'Provided `next_fn` must be a *federated* computation, that is, ' 'operate on `tff.FederatedType`s, but found\n' @@ -101,10 +103,12 @@ def _check_next_fn_is_federated(next_fn: computation_base.Computation): # TODO: b/240314933 - Move this (or refactor this) to a more general location. -def _type_check_next_fn_parameters(next_fn: computation_base.Computation): +def _type_check_next_fn_parameters( + next_fn: federated_language.framework.Computation, +): """Validates the input types of `next_fn` in a `ClientWorkProcess`.""" next_fn_param = next_fn.type_signature.parameter - if not isinstance(next_fn_param, computation_types.StructType): + if not isinstance(next_fn_param, federated_language.StructType): raise errors.TemplateNextFnNumArgsError( 'The `next_fn` must have exactly three input arguments, but found ' f'the following input type which is not a Struct: {next_fn_param}.' @@ -117,12 +121,12 @@ def _type_check_next_fn_parameters(next_fn: computation_base.Computation): ) second_next_param = next_fn_param[1] client_data_param = next_fn_param[2] - if second_next_param.placement != placements.CLIENTS: + if second_next_param.placement != federated_language.CLIENTS: raise errors.TemplatePlacementError( 'The second input argument of `next_fn` must be placed at CLIENTS ' f'but found {second_next_param}.' ) - if client_data_param.placement != placements.CLIENTS: + if client_data_param.placement != federated_language.CLIENTS: raise errors.TemplatePlacementError( 'The third input argument of `next_fn` must be placed at CLIENTS ' f'but found {client_data_param}.' @@ -135,12 +139,14 @@ def _type_check_next_fn_parameters(next_fn: computation_base.Computation): # TODO: b/240314933 - Move this (or refactor this) to a more general location. -def _type_check_next_fn_result(next_fn: computation_base.Computation): +def _type_check_next_fn_result( + next_fn: federated_language.framework.Computation, +): """Validates the output types of `next_fn` in a `ClientWorkProcess`.""" next_fn_result = next_fn.type_signature.result if ( - not isinstance(next_fn_result.result, computation_types.FederatedType) # pytype: disable=attribute-error - or next_fn_result.result.placement is not placements.CLIENTS # pytype: disable=attribute-error + not isinstance(next_fn_result.result, federated_language.FederatedType) # pytype: disable=attribute-error + or next_fn_result.result.placement is not federated_language.CLIENTS # pytype: disable=attribute-error ): raise errors.TemplatePlacementError( 'The "result" attribute of the return type of `next_fn` must be ' @@ -149,7 +155,7 @@ def _type_check_next_fn_result(next_fn: computation_base.Computation): if ( not isinstance( next_fn_result.result.member, # pytype: disable=attribute-error - computation_types.StructWithPythonType, + federated_language.StructWithPythonType, ) or next_fn_result.result.member.python_container is not ClientResult # pytype: disable=attribute-error ): @@ -157,7 +163,7 @@ def _type_check_next_fn_result(next_fn: computation_base.Computation): 'The "result" attribute of the return type of `next_fn` must have ' f'the `ClientResult` container, but found {next_fn_result.result}.' # pytype: disable=attribute-error ) - if next_fn_result.measurements.placement != placements.SERVER: # pytype: disable=attribute-error + if next_fn_result.measurements.placement != federated_language.SERVER: # pytype: disable=attribute-error raise errors.TemplatePlacementError( 'The "measurements" attribute of return type of `next_fn` must be ' f'placed at SERVER, but found {next_fn_result.measurements}.' # pytype: disable=attribute-error @@ -178,11 +184,11 @@ class ClientWorkProcess(measured_process.MeasuredProcess): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, *, - get_hparams_fn: Optional[computation_base.Computation] = None, - set_hparams_fn: Optional[computation_base.Computation] = None, + get_hparams_fn: Optional[federated_language.framework.Computation] = None, + set_hparams_fn: Optional[federated_language.framework.Computation] = None, ): """Initializes a `ClientWorkProcess`. @@ -268,9 +274,9 @@ def __init__( self._set_hparams_fn = set_hparams_fn @property - def get_hparams(self) -> computation_base.Computation: + def get_hparams(self) -> federated_language.framework.Computation: return self._get_hparams_fn # pytype: disable=attribute-error @property - def set_hparams(self) -> computation_base.Computation: + def set_hparams(self) -> federated_language.framework.Computation: return self._set_hparams_fn # pytype: disable=attribute-error diff --git a/tensorflow_federated/python/learning/templates/client_works_test.py b/tensorflow_federated/python/learning/templates/client_works_test.py index 4b9fecdcfb..34e0d6174d 100644 --- a/tensorflow_federated/python/learning/templates/client_works_test.py +++ b/tensorflow_federated/python/learning/templates/client_works_test.py @@ -15,33 +15,38 @@ import collections from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.learning.templates import hparams_base -SERVER_INT = computation_types.FederatedType(np.int32, placements.SERVER) -SERVER_FLOAT = computation_types.FederatedType(np.float32, placements.SERVER) -CLIENTS_FLOAT_SEQUENCE = computation_types.FederatedType( - computation_types.SequenceType(np.float32), placements.CLIENTS +SERVER_INT = federated_language.FederatedType( + np.int32, federated_language.SERVER ) -CLIENTS_FLOAT = computation_types.FederatedType(np.float32, placements.CLIENTS) -CLIENTS_INT = computation_types.FederatedType(np.int32, placements.CLIENTS) -MODEL_WEIGHTS_TYPE = computation_types.FederatedType( - computation_types.to_type( +SERVER_FLOAT = federated_language.FederatedType( + np.float32, federated_language.SERVER +) +CLIENTS_FLOAT_SEQUENCE = federated_language.FederatedType( + federated_language.SequenceType(np.float32), federated_language.CLIENTS +) +CLIENTS_FLOAT = federated_language.FederatedType( + np.float32, federated_language.CLIENTS +) +CLIENTS_INT = federated_language.FederatedType( + np.int32, federated_language.CLIENTS +) +MODEL_WEIGHTS_TYPE = federated_language.FederatedType( + federated_language.to_type( model_weights.ModelWeights(np.float32, np.float32) ), - placements.CLIENTS, + federated_language.CLIENTS, ) -HPARAMS_TYPE = computation_types.to_type(collections.OrderedDict(a=np.int32)) +HPARAMS_TYPE = federated_language.to_type(collections.OrderedDict(a=np.int32)) MeasuredProcessOutput = measured_process.MeasuredProcessOutput _IterativeProcessConstructionError = ( @@ -57,15 +62,15 @@ def server_zero(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) def client_one(): - return intrinsics.federated_value(1.0, placements.CLIENTS) + return federated_language.federated_value(1.0, federated_language.CLIENTS) def federated_add(a, b): - return intrinsics.federated_map( + return federated_language.federated_map( tensorflow_computation.tf_computation(lambda x, y: x + y), (a, b) ) @@ -75,14 +80,14 @@ def tf_data_sum(data): return data.reduce(0.0, lambda x, y: x + y) -@federated_computation.federated_computation() +@federated_language.federated_computation() def test_initialize_fn(): return server_zero() def test_client_result(weights, data): - reduced_data = intrinsics.federated_map(tf_data_sum, data) - return intrinsics.federated_zip( + reduced_data = federated_language.federated_map(tf_data_sum, data) + return federated_language.federated_zip( client_works.ClientResult( update=federated_add(weights.trainable, reduced_data), update_weight=client_one(), @@ -90,14 +95,14 @@ def test_client_result(weights, data): ) -@federated_computation.federated_computation( +@federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def test_next_fn(state, weights, data): return MeasuredProcessOutput( state, test_client_result(weights, data), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) @@ -121,11 +126,13 @@ def test_construction_does_not_raise(self): self.fail('Could not construct a valid ClientWorkProcess.') def test_construction_with_empty_state_does_not_raise(self): - initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value((), placements.SERVER) + initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value( + (), federated_language.SERVER + ) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE, @@ -134,7 +141,7 @@ def next_fn(state, weights, data): return MeasuredProcessOutput( state, test_client_result(weights, data), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) try: @@ -172,36 +179,40 @@ def test_set_hparams_not_tff_computation_raises(self): ) def test_init_param_not_empty_raises(self): - one_arg_initialize_fn = federated_computation.federated_computation( + one_arg_initialize_fn = federated_language.federated_computation( SERVER_INT )(lambda x: x) with self.assertRaises(errors.TemplateInitFnParamNotEmptyError): client_works.ClientWorkProcess(one_arg_initialize_fn, test_next_fn) def test_init_state_not_assignable(self): - float_initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value(0.0, placements.SERVER) + float_initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value( + 0.0, federated_language.SERVER + ) ) with self.assertRaises(errors.TemplateStateNotAssignableError): client_works.ClientWorkProcess(float_initialize_fn, test_next_fn) def test_next_state_not_assignable(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def float_next_fn(state, weights, data): del state return MeasuredProcessOutput( - intrinsics.federated_value(0.0, placements.SERVER), + federated_language.federated_value(0.0, federated_language.SERVER), test_client_result(weights, data), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(errors.TemplateStateNotAssignableError): client_works.ClientWorkProcess(test_initialize_fn, float_next_fn) def test_next_return_tuple_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def tuple_next_fn(state, weights, data): @@ -215,7 +226,7 @@ def test_next_return_namedtuple_raises(self): 'MeasuredProcessOutput', ['state', 'result', 'measurements'] ) - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def namedtuple_next_fn(state, weights, data): @@ -227,7 +238,8 @@ def namedtuple_next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, namedtuple_next_fn) def test_next_return_odict_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def odict_next_fn(state, weights, data): @@ -248,7 +260,7 @@ def test_non_federated_init_next_raises(self): @tensorflow_computation.tf_computation( np.int32, MODEL_WEIGHTS_TYPE.member, - computation_types.SequenceType(np.float32), + federated_language.SequenceType(np.float32), ) def next_fn(state, weights, data): return MeasuredProcessOutput( @@ -261,11 +273,11 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(initialize_fn, next_fn) def test_init_tuple_of_federated_types_raises(self): - initialize_fn = federated_computation.federated_computation()( + initialize_fn = federated_language.federated_computation()( lambda: (server_zero(), server_zero()) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE, @@ -279,11 +291,13 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(initialize_fn, next_fn) def test_non_server_placed_init_state_raises(self): - initialize_fn = federated_computation.federated_computation( - lambda: intrinsics.federated_value(0, placements.CLIENTS) + initialize_fn = federated_language.federated_computation( + lambda: federated_language.federated_value( + 0, federated_language.CLIENTS + ) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE, @@ -297,7 +311,8 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(initialize_fn, next_fn) def test_two_param_next_raises(self): - @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE) + + @federated_language.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE) def next_fn(state, weights): return MeasuredProcessOutput(state, weights.trainable, server_zero()) @@ -306,17 +321,19 @@ def next_fn(state, weights): def test_non_clients_placed_next_weights_param_raises(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, - computation_types.FederatedType( - MODEL_WEIGHTS_TYPE.member, placements.SERVER + federated_language.FederatedType( + MODEL_WEIGHTS_TYPE.member, federated_language.SERVER ), CLIENTS_FLOAT_SEQUENCE, ) def next_fn(state, weights, data): return MeasuredProcessOutput( state, - test_client_result(intrinsics.federated_broadcast(weights), data), + test_client_result( + federated_language.federated_broadcast(weights), data + ), server_zero(), ) @@ -324,14 +341,14 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, next_fn) def test_constructs_with_non_model_weights_parameter(self): - non_model_weights_type = computation_types.FederatedType( - computation_types.to_type( + non_model_weights_type = federated_language.FederatedType( + federated_language.to_type( collections.OrderedDict(trainable=np.float32, non_trainable=()) ), - placements.CLIENTS, + federated_language.CLIENTS, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, non_model_weights_type, CLIENTS_FLOAT_SEQUENCE ) def next_fn(state, weights, data): @@ -346,25 +363,25 @@ def next_fn(state, weights, data): def test_constructs_with_struct_of_client_data_parameter(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, - computation_types.FederatedType( + federated_language.FederatedType( ( - computation_types.SequenceType(np.float32), + federated_language.SequenceType(np.float32), ( - computation_types.SequenceType(np.float32), - computation_types.SequenceType(np.float32), + federated_language.SequenceType(np.float32), + federated_language.SequenceType(np.float32), ), ), - placements.CLIENTS, + federated_language.CLIENTS, ), ) def next_fn(state, unused_weights, unused_data): return MeasuredProcessOutput( state, - intrinsics.federated_value( - client_works.ClientResult((), ()), placements.CLIENTS + federated_language.federated_value( + client_works.ClientResult((), ()), federated_language.CLIENTS ), server_zero(), ) @@ -375,17 +392,19 @@ def next_fn(state, unused_weights, unused_data): self.fail('Could not construct a valid ClientWorkProcess.') def test_non_clients_placed_next_data_param_raises(self): - server_sequence_float_type = computation_types.FederatedType( - computation_types.SequenceType(np.float32), placements.SERVER + server_sequence_float_type = federated_language.FederatedType( + federated_language.SequenceType(np.float32), federated_language.SERVER ) - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, server_sequence_float_type ) def next_fn(state, weights, data): return MeasuredProcessOutput( state, - test_client_result(weights, intrinsics.federated_broadcast(data)), + test_client_result( + weights, federated_language.federated_broadcast(data) + ), server_zero(), ) @@ -393,13 +412,14 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, next_fn) def test_non_sequence_or_struct_next_data_param_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT ) def next_fn(state, weights, data): return MeasuredProcessOutput( state, - intrinsics.federated_zip( + federated_language.federated_zip( client_works.ClientResult( federated_add(weights.trainable, data), client_one() ) @@ -411,13 +431,14 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, next_fn) def test_non_clients_placed_next_result_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def next_fn(state, weights, data): return MeasuredProcessOutput( state, - intrinsics.federated_sum(test_client_result(weights, data)), + federated_language.federated_sum(test_client_result(weights, data)), server_zero(), ) @@ -425,11 +446,12 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, next_fn) def test_non_zipped_next_result_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def next_fn(state, weights, data): - reduced_data = intrinsics.federated_map(tf_data_sum, data) + reduced_data = federated_language.federated_map(tf_data_sum, data) return MeasuredProcessOutput( state, client_works.ClientResult( @@ -442,12 +464,13 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, next_fn) def test_incorrect_client_result_container_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def next_fn(state, weights, data): - reduced_data = intrinsics.federated_map(tf_data_sum, data) - bad_client_result = intrinsics.federated_zip( + reduced_data = federated_language.federated_map(tf_data_sum, data) + bad_client_result = federated_language.federated_zip( collections.OrderedDict( update=federated_add(weights.trainable, reduced_data), update_weight=client_one(), @@ -459,14 +482,15 @@ def next_fn(state, weights, data): client_works.ClientWorkProcess(test_initialize_fn, next_fn) def test_non_server_placed_next_measurements_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT_SEQUENCE ) def next_fn(state, weights, data): return MeasuredProcessOutput( state, test_client_result(weights, data), - intrinsics.federated_value(1.0, placements.CLIENTS), + federated_language.federated_value(1.0, federated_language.CLIENTS), ) with self.assertRaises(errors.TemplatePlacementError): diff --git a/tensorflow_federated/python/learning/templates/composers.py b/tensorflow_federated/python/learning/templates/composers.py index ffecb9223d..c870c33cd4 100644 --- a/tensorflow_federated/python/learning/templates/composers.py +++ b/tensorflow_federated/python/learning/templates/composers.py @@ -17,14 +17,11 @@ from collections.abc import Callable from typing import Any, NamedTuple +import federated_language + from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning.models import model_weights as model_weights_lib @@ -61,7 +58,7 @@ class LearningAlgorithmState(NamedTuple): # pyformat: disable def compose_learning_process( - initial_model_weights_fn: computation_base.Computation, + initial_model_weights_fn: federated_language.framework.Computation, model_weights_distributor: distributors.DistributionProcess, client_work: client_works.ClientWorkProcess, model_update_aggregator: aggregation_process.AggregationProcess, @@ -138,18 +135,18 @@ def compose_learning_process( client_work, model_update_aggregator, model_finalizer) client_data_type = client_work.next.type_signature.parameter[2] # pytype: disable=unsupported-operands - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - initial_model_weights = intrinsics.federated_eval(initial_model_weights_fn, - placements.SERVER) - return intrinsics.federated_zip( + initial_model_weights = federated_language.federated_eval(initial_model_weights_fn, + federated_language.SERVER) + return federated_language.federated_zip( LearningAlgorithmState(initial_model_weights, model_weights_distributor.initialize(), client_work.initialize(), model_update_aggregator.initialize(), model_finalizer.initialize())) - @federated_computation.federated_computation(init_fn.type_signature.result, + @federated_language.federated_computation(init_fn.type_signature.result, client_data_type) def next_fn(state, client_data): # Compose processes. @@ -167,12 +164,12 @@ def next_fn(state, client_data): # Form the learning process output. new_global_model_weights = finalizer_output.result - new_state = intrinsics.federated_zip( + new_state = federated_language.federated_zip( LearningAlgorithmState(new_global_model_weights, distributor_output.state, client_work_output.state, aggregator_output.state, finalizer_output.state)) - metrics = intrinsics.federated_zip( + metrics = federated_language.federated_zip( collections.OrderedDict( distributor=distributor_output.measurements, client_work=client_work_output.measurements, @@ -248,20 +245,20 @@ def _validate_args(initial_model_weights_fn, model_weights_distributor, client_work, model_update_aggregator, model_finalizer): """Checks `compose_learning_process` args meet the documented constraints.""" py_typecheck.check_type(initial_model_weights_fn, - computation_base.Computation) + federated_language.framework.Computation) if initial_model_weights_fn.type_signature.parameter is not None: raise TypeError( f'Provided initial_model_weights_fn must be a no-arg tff.Computation.\n' f'Found input parameter: ' f'{initial_model_weights_fn.type_signature.parameter}') global_model_weights_type = initial_model_weights_fn.type_signature.result - if isinstance(global_model_weights_type, computation_types.FederatedType): + if isinstance(global_model_weights_type, federated_language.FederatedType): raise TypeError( f'Provided initial_model_weights_fn must be a tff.Computation with ' f'unplaced return type.\n' f'Return type found: {global_model_weights_type}') - global_model_weights_type = computation_types.FederatedType( - global_model_weights_type, placements.SERVER) + global_model_weights_type = federated_language.FederatedType( + global_model_weights_type, federated_language.SERVER) py_typecheck.check_type(model_weights_distributor, distributors.DistributionProcess) py_typecheck.check_type(client_work, client_works.ClientWorkProcess) diff --git a/tensorflow_federated/python/learning/templates/composers_test.py b/tensorflow_federated/python/learning/templates/composers_test.py index ee494c746a..70b975ccc3 100644 --- a/tensorflow_federated/python/learning/templates/composers_test.py +++ b/tensorflow_federated/python/learning/templates/composers_test.py @@ -15,6 +15,7 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf @@ -23,10 +24,6 @@ from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning.models import keras_utils @@ -41,20 +38,20 @@ from tensorflow_federated.python.learning.templates import learning_process from tensorflow_federated.python.learning.templates import model_delta_client_work -FLOAT_TYPE = computation_types.TensorType(np.float32) -MODEL_WEIGHTS_TYPE = computation_types.to_type( +FLOAT_TYPE = federated_language.TensorType(np.float32) +MODEL_WEIGHTS_TYPE = federated_language.to_type( model_weights_lib.ModelWeights(FLOAT_TYPE, ()) ) -CLIENTS_SEQUENCE_FLOAT_TYPE = computation_types.FederatedType( - computation_types.SequenceType(FLOAT_TYPE), placements.CLIENTS +CLIENTS_SEQUENCE_FLOAT_TYPE = federated_language.FederatedType( + federated_language.SequenceType(FLOAT_TYPE), federated_language.CLIENTS ) def empty_at_server(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) -@federated_computation.federated_computation() +@federated_language.federated_computation() def empty_init_fn(): return empty_at_server() @@ -68,13 +65,15 @@ def test_init_model_weights_fn(): def test_distributor(): - @federated_computation.federated_computation( + @federated_language.federated_computation( empty_init_fn.type_signature.result, - computation_types.FederatedType(MODEL_WEIGHTS_TYPE, placements.SERVER), + federated_language.FederatedType( + MODEL_WEIGHTS_TYPE, federated_language.SERVER + ), ) def next_fn(state, value): return measured_process.MeasuredProcessOutput( - state, intrinsics.federated_broadcast(value), empty_at_server() + state, federated_language.federated_broadcast(value), empty_at_server() ) return distributors.DistributionProcess(empty_init_fn, next_fn) @@ -88,13 +87,15 @@ def make_result(value, data): update_weight=data.reduce(0.0, lambda x, y: x + y), ) - @federated_computation.federated_computation( + @federated_language.federated_computation( empty_init_fn.type_signature.result, - computation_types.FederatedType(MODEL_WEIGHTS_TYPE, placements.CLIENTS), + federated_language.FederatedType( + MODEL_WEIGHTS_TYPE, federated_language.CLIENTS + ), CLIENTS_SEQUENCE_FLOAT_TYPE, ) def next_fn(state, value, client_data): - result = intrinsics.federated_map(make_result, (value, client_data)) + result = federated_language.federated_map(make_result, (value, client_data)) return measured_process.MeasuredProcessOutput( state, result, empty_at_server() ) @@ -108,17 +109,19 @@ def test_aggregator(): def test_finalizer(): - @federated_computation.federated_computation( + @federated_language.federated_computation( empty_init_fn.type_signature.result, - computation_types.FederatedType(MODEL_WEIGHTS_TYPE, placements.SERVER), - computation_types.FederatedType(FLOAT_TYPE, placements.SERVER), + federated_language.FederatedType( + MODEL_WEIGHTS_TYPE, federated_language.SERVER + ), + federated_language.FederatedType(FLOAT_TYPE, federated_language.SERVER), ) def next_fn(state, weights, updates): - new_weights = intrinsics.federated_map( + new_weights = federated_language.federated_map( tensorflow_computation.tf_computation(lambda x, y: x + y), (weights.trainable, updates), ) - new_weights = intrinsics.federated_zip( + new_weights = federated_language.federated_zip( model_weights_lib.ModelWeights(new_weights, ()) ) return measured_process.MeasuredProcessOutput( @@ -163,7 +166,7 @@ def test_learning_process_composes(self): def test_one_arg_computation_init_raises(self): @tensorflow_computation.tf_computation( - computation_types.TensorType(np.float32) + federated_language.TensorType(np.float32) ) def init_model_weights_fn(x): return model_weights_lib.ModelWeights(trainable=x, non_trainable=()) @@ -193,10 +196,11 @@ def init_model_weights_fn(): ) def test_federated_init_raises(self): - @federated_computation.federated_computation() + + @federated_language.federated_computation() def init_model_weights_fn(): - return intrinsics.federated_eval( - test_init_model_weights_fn, placements.SERVER + return federated_language.federated_eval( + test_init_model_weights_fn, federated_language.SERVER ) with self.assertRaisesRegex(TypeError, 'unplaced'): diff --git a/tensorflow_federated/python/learning/templates/distributors.py b/tensorflow_federated/python/learning/templates/distributors.py index 623ef7b0e6..3cea1f0749 100644 --- a/tensorflow_federated/python/learning/templates/distributors.py +++ b/tensorflow_federated/python/learning/templates/distributors.py @@ -17,13 +17,9 @@ `tff.distributors` and `tff.templates` later on. """ +import federated_language from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process @@ -54,7 +50,7 @@ def __init__(self, initialize_fn, next_fn): super().__init__(initialize_fn, next_fn, next_is_multi_arg=True) if not isinstance( - initialize_fn.type_signature.result, computation_types.FederatedType + initialize_fn.type_signature.result, federated_language.FederatedType ): raise errors.TemplateNotFederatedError( 'Provided `initialize_fn` must return a federated type, but found ' @@ -66,15 +62,13 @@ def __init__(self, initialize_fn, next_fn): next_fn.type_signature.parameter ) + structure.flatten(next_fn.type_signature.result) if not all( - [isinstance(t, computation_types.FederatedType) for t in next_types] + [isinstance(t, federated_language.FederatedType) for t in next_types] ): - offending_types = '\n- '.join( - [ - t - for t in next_types - if not isinstance(t, computation_types.FederatedType) - ] - ) + offending_types = '\n- '.join([ + t + for t in next_types + if not isinstance(t, federated_language.FederatedType) + ]) raise errors.TemplateNotFederatedError( 'Provided `next_fn` must be a *federated* computation, that is, ' 'operate on `tff.FederatedType`s, but found\n' @@ -82,7 +76,10 @@ def __init__(self, initialize_fn, next_fn): f'The non-federated types are:\n {offending_types}.' ) - if initialize_fn.type_signature.result.placement != placements.SERVER: + if ( + initialize_fn.type_signature.result.placement + != federated_language.SERVER + ): raise errors.TemplatePlacementError( 'The state controlled by an `DistributionProcess` must be placed at ' f'the SERVER, but found type: {initialize_fn.type_signature.result}.' @@ -93,7 +90,7 @@ def __init__(self, initialize_fn, next_fn): next_fn_param = next_fn.type_signature.parameter next_fn_result = next_fn.type_signature.result - if not isinstance(next_fn_param, computation_types.StructType): + if not isinstance(next_fn_param, federated_language.StructType): raise errors.TemplateNextFnNumArgsError( 'The `next_fn` must have exactly two input arguments, but found ' f'the following input type which is not a Struct: {next_fn_param}.' @@ -104,18 +101,18 @@ def __init__(self, initialize_fn, next_fn): 'The `next_fn` must have exactly two input arguments, but found ' f'{len(next_fn_param)} input arguments:\n{next_param_str}' ) - if next_fn_param[1].placement != placements.SERVER: + if next_fn_param[1].placement != federated_language.SERVER: raise errors.TemplatePlacementError( 'The second input argument of `next_fn` must be placed at SERVER ' f'but found {next_fn_param[1]}.' ) - if next_fn_result.result.placement != placements.CLIENTS: + if next_fn_result.result.placement != federated_language.CLIENTS: raise errors.TemplatePlacementError( 'The "result" attribute of return type of `next_fn` must be placed ' f'at CLIENTS, but found {next_fn_result.result}.' ) - if next_fn_result.measurements.placement != placements.SERVER: + if next_fn_result.measurements.placement != federated_language.SERVER: raise errors.TemplatePlacementError( 'The "measurements" attribute of return type of `next_fn` must be ' f'placed at SERVER, but found {next_fn_result.measurements}.' @@ -123,7 +120,7 @@ def __init__(self, initialize_fn, next_fn): # TODO: b/190334722 - Replace with a factory pattern similar to tff.aggregators. -def build_broadcast_process(value_type: computation_types.Type): +def build_broadcast_process(value_type: federated_language.Type): """Builds `DistributionProcess` directly broadcasting values. The created process has empty state and reports no measurements. @@ -138,26 +135,28 @@ def build_broadcast_process(value_type: computation_types.Type): TypeError: If `value_type` contains a `tff.types.FederatedType`. """ py_typecheck.check_type( - value_type, (computation_types.TensorType, computation_types.StructType) + value_type, (federated_language.TensorType, federated_language.StructType) ) - if type_analysis.contains_federated_types(value_type): + if federated_language.framework.contains_federated_types(value_type): raise TypeError( 'Provided value_type must not contain any tff.types.FederatedType, ' f'but found: {value_type}' ) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(value_type, placements.SERVER), + federated_language.FederatedType(value_type, federated_language.SERVER), ) def next_fn(state, value): - empty_measurements = intrinsics.federated_value((), placements.SERVER) + empty_measurements = federated_language.federated_value( + (), federated_language.SERVER + ) return measured_process.MeasuredProcessOutput( - state, intrinsics.federated_broadcast(value), empty_measurements + state, federated_language.federated_broadcast(value), empty_measurements ) return DistributionProcess(init_fn, next_fn) diff --git a/tensorflow_federated/python/learning/templates/distributors_test.py b/tensorflow_federated/python/learning/templates/distributors_test.py index 0f0b3cf0b3..fb74d26109 100644 --- a/tensorflow_federated/python/learning/templates/distributors_test.py +++ b/tensorflow_federated/python/learning/templates/distributors_test.py @@ -15,22 +15,25 @@ import collections from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.templates import distributors -SERVER_INT = computation_types.FederatedType(np.int32, placements.SERVER) -SERVER_FLOAT = computation_types.FederatedType(np.float32, placements.SERVER) -CLIENTS_INT = computation_types.FederatedType(np.int32, placements.CLIENTS) +SERVER_INT = federated_language.FederatedType( + np.int32, federated_language.SERVER +) +SERVER_FLOAT = federated_language.FederatedType( + np.float32, federated_language.SERVER +) +CLIENTS_INT = federated_language.FederatedType( + np.int32, federated_language.CLIENTS +) MeasuredProcessOutput = measured_process.MeasuredProcessOutput _DistributionProcessConstructionError = ( @@ -41,20 +44,20 @@ def server_zero(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) -@federated_computation.federated_computation() +@federated_language.federated_computation() def test_initialize_fn(): return server_zero() -@federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) +@federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def test_next_fn(state, val): return MeasuredProcessOutput( state, - intrinsics.federated_broadcast(val), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_broadcast(val), + federated_language.federated_value(1, federated_language.SERVER), ) @@ -67,18 +70,20 @@ def test_construction_does_not_raise(self): self.fail('Could not construct a valid DistributionProcess.') def test_construction_with_empty_state_does_not_raise(self): - initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value((), placements.SERVER) + initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value( + (), federated_language.SERVER + ) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, SERVER_FLOAT ) def next_fn(state, val): return MeasuredProcessOutput( state, - intrinsics.federated_broadcast(val), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_broadcast(val), + federated_language.federated_value(1, federated_language.SERVER), ) try: @@ -100,26 +105,29 @@ def test_next_not_tff_computation_raises(self): ) def test_init_param_not_empty_raises(self): - one_arg_initialize_fn = federated_computation.federated_computation( + one_arg_initialize_fn = federated_language.federated_computation( SERVER_INT )(lambda x: x) with self.assertRaises(errors.TemplateInitFnParamNotEmptyError): distributors.DistributionProcess(one_arg_initialize_fn, test_next_fn) def test_init_state_not_assignable(self): - float_initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value(0.0, placements.SERVER) + float_initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value( + 0.0, federated_language.SERVER + ) ) with self.assertRaises(errors.TemplateStateNotAssignableError): distributors.DistributionProcess(float_initialize_fn, test_next_fn) def test_next_state_not_assignable(self): - @federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) + + @federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def float_next_fn(state, val): del state return MeasuredProcessOutput( - intrinsics.federated_value(0.0, placements.SERVER), - intrinsics.federated_broadcast(val), + federated_language.federated_value(0.0, federated_language.SERVER), + federated_language.federated_broadcast(val), server_zero(), ) @@ -127,9 +135,10 @@ def float_next_fn(state, val): distributors.DistributionProcess(test_initialize_fn, float_next_fn) def test_next_return_tuple_raises(self): - @federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) + + @federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def tuple_next_fn(state, val): - return state, intrinsics.federated_broadcast(val), server_zero() + return state, federated_language.federated_broadcast(val), server_zero() with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError): distributors.DistributionProcess(test_initialize_fn, tuple_next_fn) @@ -139,21 +148,22 @@ def test_next_return_namedtuple_raises(self): 'MeasuredProcessOutput', ['state', 'result', 'measurements'] ) - @federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) + @federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def namedtuple_next_fn(state, val): return measured_process_output( - state, intrinsics.federated_broadcast(val), server_zero() + state, federated_language.federated_broadcast(val), server_zero() ) with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError): distributors.DistributionProcess(test_initialize_fn, namedtuple_next_fn) def test_next_return_odict_raises(self): - @federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) + + @federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def odict_next_fn(state, val): return collections.OrderedDict( state=state, - result=intrinsics.federated_broadcast(val), + result=federated_language.federated_broadcast(val), measurements=server_zero(), ) @@ -167,13 +177,15 @@ def test_construction_with_value_type_mismatch_does_not_raise(self): lambda x: tf.cast(x, tf.float64) ) - @federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) + @federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def next_fn(state, val): - result = intrinsics.federated_map( - bad_cast_fn, intrinsics.federated_broadcast(val) + result = federated_language.federated_map( + bad_cast_fn, federated_language.federated_broadcast(val) ) return MeasuredProcessOutput( - state, result, intrinsics.federated_value(1, placements.SERVER) + state, + result, + federated_language.federated_value(1, federated_language.SERVER), ) try: @@ -195,78 +207,83 @@ def next_fn(state, val): distributors.DistributionProcess(initialize_fn, next_fn) def test_init_tuple_of_federated_types_raises(self): - initialize_fn = federated_computation.federated_computation()( + initialize_fn = federated_language.federated_computation()( lambda: (server_zero(), server_zero()) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, SERVER_FLOAT ) def next_fn(state, val): return MeasuredProcessOutput( - state, intrinsics.federated_broadcast(val), server_zero() + state, federated_language.federated_broadcast(val), server_zero() ) with self.assertRaises(errors.TemplateNotFederatedError): distributors.DistributionProcess(initialize_fn, next_fn) def test_non_server_placed_init_state_raises(self): - initialize_fn = federated_computation.federated_computation( - lambda: intrinsics.federated_value(0, placements.CLIENTS) + initialize_fn = federated_language.federated_computation( + lambda: federated_language.federated_value( + 0, federated_language.CLIENTS + ) ) - @federated_computation.federated_computation(CLIENTS_INT, SERVER_FLOAT) + @federated_language.federated_computation(CLIENTS_INT, SERVER_FLOAT) def next_fn(state, val): return MeasuredProcessOutput( - state, intrinsics.federated_broadcast(val), server_zero() + state, federated_language.federated_broadcast(val), server_zero() ) with self.assertRaises(errors.TemplatePlacementError): distributors.DistributionProcess(initialize_fn, next_fn) def test_single_param_next_raises(self): - @federated_computation.federated_computation(SERVER_INT) + + @federated_language.federated_computation(SERVER_INT) def next_fn(state): return MeasuredProcessOutput( - state, intrinsics.federated_broadcast(state), server_zero() + state, federated_language.federated_broadcast(state), server_zero() ) with self.assertRaises(errors.TemplateNextFnNumArgsError): distributors.DistributionProcess(test_initialize_fn, next_fn) def test_three_params_next_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, SERVER_FLOAT, SERVER_FLOAT ) def next_fn(state, value, extra_value): return MeasuredProcessOutput( - state, intrinsics.federated_broadcast(value), extra_value + state, federated_language.federated_broadcast(value), extra_value ) with self.assertRaises(errors.TemplateNextFnNumArgsError): distributors.DistributionProcess(test_initialize_fn, next_fn) def test_non_server_placed_next_value_param_raises(self): - next_fn = federated_computation.federated_computation( - SERVER_INT, CLIENTS_INT - )(lambda state, val: MeasuredProcessOutput(state, val, server_zero())) + next_fn = federated_language.federated_computation(SERVER_INT, CLIENTS_INT)( + lambda state, val: MeasuredProcessOutput(state, val, server_zero()) + ) with self.assertRaises(errors.TemplatePlacementError): distributors.DistributionProcess(test_initialize_fn, next_fn) def test_non_clients_placed_next_result_raises(self): - next_fn = federated_computation.federated_computation( - SERVER_INT, SERVER_INT - )(lambda state, val: MeasuredProcessOutput(state, val, server_zero())) + next_fn = federated_language.federated_computation(SERVER_INT, SERVER_INT)( + lambda state, val: MeasuredProcessOutput(state, val, server_zero()) + ) with self.assertRaises(errors.TemplatePlacementError): distributors.DistributionProcess(test_initialize_fn, next_fn) def test_non_server_placed_next_measurements_raises(self): - @federated_computation.federated_computation(SERVER_INT, SERVER_FLOAT) + + @federated_language.federated_computation(SERVER_INT, SERVER_FLOAT) def next_fn(state, val): return MeasuredProcessOutput( state, - intrinsics.federated_broadcast(val), - intrinsics.federated_value(1.0, placements.CLIENTS), + federated_language.federated_broadcast(val), + federated_language.federated_value(1.0, federated_language.CLIENTS), ) with self.assertRaises(errors.TemplatePlacementError): @@ -276,32 +293,34 @@ def next_fn(state, val): class BroadcastProcessComputationTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( - ('float', computation_types.TensorType(np.float32)), - ('struct', computation_types.to_type([(np.float32, (2,)), np.int32])), + ('float', federated_language.TensorType(np.float32)), + ('struct', federated_language.to_type([(np.float32, (2,)), np.int32])), ) def test_type_properties(self, value_type): broadcast_process = distributors.build_broadcast_process(value_type) self.assertIsInstance(broadcast_process, distributors.DistributionProcess) - expected_param_value_type = computation_types.FederatedType( - value_type, placements.SERVER + expected_param_value_type = federated_language.FederatedType( + value_type, federated_language.SERVER + ) + expected_result_type = federated_language.FederatedType( + value_type, federated_language.CLIENTS, all_equal=True ) - expected_result_type = computation_types.FederatedType( - value_type, placements.CLIENTS, all_equal=True + expected_state_type = federated_language.FederatedType( + (), federated_language.SERVER ) - expected_state_type = computation_types.FederatedType((), placements.SERVER) - expected_measurements_type = computation_types.FederatedType( - (), placements.SERVER + expected_measurements_type = federated_language.FederatedType( + (), federated_language.SERVER ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) expected_initialize_type.check_equivalent_to( broadcast_process.initialize.type_signature ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, value=expected_param_value_type ), @@ -317,8 +336,8 @@ def test_type_properties(self, value_type): @parameterized.named_parameters( ('federated_type', SERVER_FLOAT), - ('function_type', computation_types.FunctionType(None, ())), - ('sequence_type', computation_types.SequenceType(np.float32)), + ('function_type', federated_language.FunctionType(None, ())), + ('sequence_type', federated_language.SequenceType(np.float32)), ) def test_incorrect_value_type_raises(self, bad_value_type): with self.assertRaises(TypeError): @@ -327,7 +346,7 @@ def test_incorrect_value_type_raises(self, bad_value_type): def test_inner_federated_type_raises(self): with self.assertRaisesRegex(TypeError, 'FederatedType'): distributors.build_broadcast_process( - computation_types.to_type([SERVER_FLOAT, SERVER_FLOAT]) + federated_language.to_type([SERVER_FLOAT, SERVER_FLOAT]) ) @@ -343,7 +362,7 @@ def test_broadcast_scalar(self): self.assertEqual((), output.measurements) def test_broadcast_struct(self): - struct_type = computation_types.to_type([(np.float32, (2,)), np.int32]) + struct_type = federated_language.to_type([(np.float32, (2,)), np.int32]) broadcast_process = distributors.build_broadcast_process(struct_type) output = broadcast_process.next( broadcast_process.initialize(), ((1.0, 2.5), 3) diff --git a/tensorflow_federated/python/learning/templates/finalizers.py b/tensorflow_federated/python/learning/templates/finalizers.py index 7c1dde8284..1927c1e889 100644 --- a/tensorflow_federated/python/learning/templates/finalizers.py +++ b/tensorflow_federated/python/learning/templates/finalizers.py @@ -15,12 +15,9 @@ from typing import Optional +import federated_language + from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.templates import hparams_base @@ -40,11 +37,11 @@ class FinalizerProcess(measured_process.MeasuredProcess): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, *, - get_hparams_fn: Optional[computation_base.Computation] = None, - set_hparams_fn: Optional[computation_base.Computation] = None, + get_hparams_fn: Optional[federated_language.framework.Computation] = None, + set_hparams_fn: Optional[federated_language.framework.Computation] = None, ): """Initializes a `FinalizerProcess`. @@ -106,7 +103,7 @@ def __init__( super().__init__(initialize_fn, next_fn, next_is_multi_arg=True) if not isinstance( - initialize_fn.type_signature.result, computation_types.FederatedType + initialize_fn.type_signature.result, federated_language.FederatedType ): raise errors.TemplateNotFederatedError( 'Provided `initialize_fn` must return a federated type, but found ' @@ -118,15 +115,13 @@ def __init__( next_fn.type_signature.parameter ) + structure.flatten(next_fn.type_signature.result) if not all( - [isinstance(t, computation_types.FederatedType) for t in next_types] + [isinstance(t, federated_language.FederatedType) for t in next_types] ): - offending_types = '\n- '.join( - [ - t - for t in next_types - if not isinstance(t, computation_types.FederatedType) - ] - ) + offending_types = '\n- '.join([ + t + for t in next_types + if not isinstance(t, federated_language.FederatedType) + ]) raise errors.TemplateNotFederatedError( 'Provided `next_fn` must be a *federated* computation, that is, ' 'operate on `tff.FederatedType`s, but found\n' @@ -134,7 +129,10 @@ def __init__( f'The non-federated types are:\n {offending_types}.' ) - if initialize_fn.type_signature.result.placement != placements.SERVER: + if ( + initialize_fn.type_signature.result.placement + != federated_language.SERVER + ): raise errors.TemplatePlacementError( 'The state controlled by an `FinalizerProcess` must be placed at ' f'the SERVER, but found type: {initialize_fn.type_signature.result}.' @@ -144,7 +142,7 @@ def __init__( # TemplateStateNotAssignableError. next_fn_param = next_fn.type_signature.parameter - if not isinstance(next_fn_param, computation_types.StructType): + if not isinstance(next_fn_param, federated_language.StructType): raise errors.TemplateNextFnNumArgsError( 'The `next_fn` must have exactly two input arguments, but found ' f'the following input type which is not a Struct: {next_fn_param}.' @@ -157,19 +155,19 @@ def __init__( ) model_weights_param = next_fn_param[1] update_from_clients_param = next_fn_param[2] - if model_weights_param.placement != placements.SERVER: + if model_weights_param.placement != federated_language.SERVER: raise errors.TemplatePlacementError( 'The second input argument of `next_fn` must be placed at SERVER ' f'but found {model_weights_param}.' ) - if update_from_clients_param.placement != placements.SERVER: + if update_from_clients_param.placement != federated_language.SERVER: raise errors.TemplatePlacementError( 'The third input argument of `next_fn` must be placed at SERVER ' f'but found {update_from_clients_param}.' ) next_fn_result = next_fn.type_signature.result - if next_fn_result.result.placement != placements.SERVER: + if next_fn_result.result.placement != federated_language.SERVER: raise errors.TemplatePlacementError( 'The "result" attribute of the return type of `next_fn` must be ' f'placed at SERVER, but found {next_fn_result.result}.' @@ -183,7 +181,7 @@ def __init__( f'Second input argument: {next_fn_param[1].member}\n' f'Result attribute: {next_fn_result.result.member}.' ) - if next_fn_result.measurements.placement != placements.SERVER: + if next_fn_result.measurements.placement != federated_language.SERVER: raise errors.TemplatePlacementError( 'The "measurements" attribute of return type of `next_fn` must be ' f'placed at SERVER, but found {next_fn_result.measurements}.' @@ -208,35 +206,39 @@ def __init__( self._set_hparams_fn = set_hparams_fn @property - def get_hparams(self) -> computation_base.Computation: + def get_hparams(self) -> federated_language.framework.Computation: return self._get_hparams_fn # pytype: disable=attribute-error @property - def set_hparams(self) -> computation_base.Computation: + def set_hparams(self) -> federated_language.framework.Computation: return self._set_hparams_fn # pytype: disable=attribute-error def build_identity_finalizer( - model_weights_type: computation_types.StructType, - update_type: computation_types.StructType, + model_weights_type: federated_language.StructType, + update_type: federated_language.StructType, ) -> FinalizerProcess: """Builds a `FinalizerProcess` that performs no update on model weights.""" - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) # The type signature of `next` function is defined so that the created # `tff.learning.templates.FinalizerProcess` can be used in # `tff.learning.templates.compose_learning_process`. - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(model_weights_type, placements.SERVER), - computation_types.FederatedType(update_type, placements.SERVER), + federated_language.FederatedType( + model_weights_type, federated_language.SERVER + ), + federated_language.FederatedType(update_type, federated_language.SERVER), ) def next_fn(state, weights, update): del update - empty_measurements = intrinsics.federated_value((), placements.SERVER) + empty_measurements = federated_language.federated_value( + (), federated_language.SERVER + ) return measured_process.MeasuredProcessOutput( state, weights, empty_measurements ) diff --git a/tensorflow_federated/python/learning/templates/finalizers_test.py b/tensorflow_federated/python/learning/templates/finalizers_test.py index da959519ae..85458828e5 100644 --- a/tensorflow_federated/python/learning/templates/finalizers_test.py +++ b/tensorflow_federated/python/learning/templates/finalizers_test.py @@ -14,29 +14,34 @@ import collections +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.templates import finalizers from tensorflow_federated.python.learning.templates import hparams_base -SERVER_INT = computation_types.FederatedType(np.int32, placements.SERVER) -SERVER_FLOAT = computation_types.FederatedType(np.float32, placements.SERVER) -CLIENTS_INT = computation_types.FederatedType(np.int32, placements.CLIENTS) -CLIENTS_FLOAT = computation_types.FederatedType(np.float32, placements.CLIENTS) -MODEL_WEIGHTS_TYPE = computation_types.FederatedType( - computation_types.to_type( +SERVER_INT = federated_language.FederatedType( + np.int32, federated_language.SERVER +) +SERVER_FLOAT = federated_language.FederatedType( + np.float32, federated_language.SERVER +) +CLIENTS_INT = federated_language.FederatedType( + np.int32, federated_language.CLIENTS +) +CLIENTS_FLOAT = federated_language.FederatedType( + np.float32, federated_language.CLIENTS +) +MODEL_WEIGHTS_TYPE = federated_language.FederatedType( + federated_language.to_type( model_weights.ModelWeights(np.float32, np.float32) ), - placements.SERVER, + federated_language.SERVER, ) MeasuredProcessOutput = measured_process.MeasuredProcessOutput @@ -52,36 +57,36 @@ def server_zero(): """Returns zero integer placed at SERVER.""" - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) def federated_add(a, b): - return intrinsics.federated_map( + return federated_language.federated_map( tensorflow_computation.tf_computation(lambda x, y: x + y), (a, b) ) -@federated_computation.federated_computation() +@federated_language.federated_computation() def test_initialize_fn(): return server_zero() def test_finalizer_result(weights, update): - return intrinsics.federated_zip( + return federated_language.federated_zip( model_weights.ModelWeights( federated_add(weights.trainable, update), weights.non_trainable ) ) -@federated_computation.federated_computation( +@federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def test_next_fn(state, weights, update): return MeasuredProcessOutput( state, test_finalizer_result(weights, update), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) @@ -94,18 +99,20 @@ def test_construction_does_not_raise(self): self.fail('Could not construct a valid FinalizerProcess.') def test_construction_with_empty_state_does_not_raise(self): - initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value((), placements.SERVER) + initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value( + (), federated_language.SERVER + ) ) - model_weights_type = computation_types.StructWithPythonType( + model_weights_type = federated_language.StructWithPythonType( [('trainable', np.float32), ('non_trainable', ())], model_weights.ModelWeights, ) - server_model_weights_type = computation_types.FederatedType( - model_weights_type, placements.SERVER + server_model_weights_type = federated_language.FederatedType( + model_weights_type, federated_language.SERVER ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, server_model_weights_type, SERVER_FLOAT, @@ -114,7 +121,7 @@ def next_fn(state, weights, update): return MeasuredProcessOutput( state, test_finalizer_result(weights, update), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) try: @@ -134,36 +141,40 @@ def test_next_not_tff_computation_raises(self): ) def test_init_param_not_empty_raises(self): - one_arg_initialize_fn = federated_computation.federated_computation( + one_arg_initialize_fn = federated_language.federated_computation( SERVER_INT )(lambda x: x) with self.assertRaises(errors.TemplateInitFnParamNotEmptyError): finalizers.FinalizerProcess(one_arg_initialize_fn, test_next_fn) def test_init_state_not_assignable(self): - float_initialize_fn = federated_computation.federated_computation()( - lambda: intrinsics.federated_value(0.0, placements.SERVER) + float_initialize_fn = federated_language.federated_computation()( + lambda: federated_language.federated_value( + 0.0, federated_language.SERVER + ) ) with self.assertRaises(errors.TemplateStateNotAssignableError): finalizers.FinalizerProcess(float_initialize_fn, test_next_fn) def test_next_state_not_assignable(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def float_next_fn(state, weights, update): del state return MeasuredProcessOutput( - intrinsics.federated_value(0.0, placements.SERVER), + federated_language.federated_value(0.0, federated_language.SERVER), test_finalizer_result(weights, update), - intrinsics.federated_value(1, placements.SERVER), + federated_language.federated_value(1, federated_language.SERVER), ) with self.assertRaises(errors.TemplateStateNotAssignableError): finalizers.FinalizerProcess(test_initialize_fn, float_next_fn) def test_next_return_tuple_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def tuple_next_fn(state, weights, update): @@ -177,7 +188,7 @@ def test_next_return_namedtuple_raises(self): 'MeasuredProcessOutput', ['state', 'result', 'measurements'] ) - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def namedtuple_next_fn(state, weights, update): @@ -189,7 +200,8 @@ def namedtuple_next_fn(state, weights, update): finalizers.FinalizerProcess(test_initialize_fn, namedtuple_next_fn) def test_next_return_odict_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def odict_next_fn(state, weights, update): @@ -206,7 +218,7 @@ def odict_next_fn(state, weights, update): def test_non_federated_init_next_raises(self): initialize_fn = tensorflow_computation.tf_computation(lambda: 0) - model_weights_type = computation_types.StructWithPythonType( + model_weights_type = federated_language.StructWithPythonType( [('trainable', np.float32), ('non_trainable', ())], model_weights.ModelWeights, ) @@ -224,11 +236,11 @@ def next_fn(state, weights, update): finalizers.FinalizerProcess(initialize_fn, next_fn) def test_init_tuple_of_federated_types_raises(self): - initialize_fn = federated_computation.federated_computation()( + initialize_fn = federated_language.federated_computation()( lambda: (server_zero(), server_zero()) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def next_fn(state, weights, update): @@ -240,11 +252,13 @@ def next_fn(state, weights, update): finalizers.FinalizerProcess(initialize_fn, next_fn) def test_non_server_placed_init_state_raises(self): - initialize_fn = federated_computation.federated_computation( - lambda: intrinsics.federated_value(0, placements.CLIENTS) + initialize_fn = federated_language.federated_computation( + lambda: federated_language.federated_value( + 0, federated_language.CLIENTS + ) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( CLIENTS_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def next_fn(state, weights, update): @@ -256,7 +270,8 @@ def next_fn(state, weights, update): finalizers.FinalizerProcess(initialize_fn, next_fn) def test_two_param_next_raises(self): - @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE) + + @federated_language.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE) def next_fn(state, weights): return MeasuredProcessOutput(state, weights, server_zero()) @@ -265,17 +280,19 @@ def next_fn(state, weights): def test_non_server_placed_next_weight_param_raises(self): - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, - computation_types.FederatedType( - MODEL_WEIGHTS_TYPE.member, placements.CLIENTS + federated_language.FederatedType( + MODEL_WEIGHTS_TYPE.member, federated_language.CLIENTS ), SERVER_FLOAT, ) def next_fn(state, weights, update): return MeasuredProcessOutput( state, - test_finalizer_result(intrinsics.federated_sum(weights), update), + test_finalizer_result( + federated_language.federated_sum(weights), update + ), server_zero(), ) @@ -283,14 +300,14 @@ def next_fn(state, weights, update): finalizers.FinalizerProcess(test_initialize_fn, next_fn) def test_constructs_with_non_model_weights_parameter(self): - non_model_weights_type = computation_types.FederatedType( - computation_types.to_type( + non_model_weights_type = federated_language.FederatedType( + federated_language.to_type( collections.OrderedDict(trainable=np.float32, non_trainable=()) ), - placements.SERVER, + federated_language.SERVER, ) - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, non_model_weights_type, SERVER_FLOAT ) def next_fn(state, weights, update): @@ -303,13 +320,16 @@ def next_fn(state, weights, update): self.fail('Could not construct a valid FinalizerProcess.') def test_non_server_placed_next_update_param_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, CLIENTS_FLOAT ) def next_fn(state, weights, update): return MeasuredProcessOutput( state, - test_finalizer_result(weights, intrinsics.federated_sum(update)), + test_finalizer_result( + weights, federated_language.federated_sum(update) + ), server_zero(), ) @@ -317,13 +337,14 @@ def next_fn(state, weights, update): finalizers.FinalizerProcess(test_initialize_fn, next_fn) def test_non_server_placed_next_result_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def next_fn(state, weights, update): return MeasuredProcessOutput( state, - intrinsics.federated_broadcast( + federated_language.federated_broadcast( test_finalizer_result(weights, update) ), server_zero(), @@ -337,13 +358,13 @@ def test_result_not_assignable_to_weight_raises(self): lambda x: tf.nest.map_structure(lambda y: tf.cast(y, tf.float64), x) ) - @federated_computation.federated_computation( + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def next_fn(state, weights, update): return MeasuredProcessOutput( state, - intrinsics.federated_map( + federated_language.federated_map( bad_cast_fn, test_finalizer_result(weights, update) ), server_zero(), @@ -353,14 +374,15 @@ def next_fn(state, weights, update): finalizers.FinalizerProcess(test_initialize_fn, next_fn) def test_non_server_placed_next_measurements_raises(self): - @federated_computation.federated_computation( + + @federated_language.federated_computation( SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT ) def next_fn(state, weights, update): return MeasuredProcessOutput( state, test_finalizer_result(weights, update), - intrinsics.federated_value(1.0, placements.CLIENTS), + federated_language.federated_value(1.0, federated_language.CLIENTS), ) with self.assertRaises(errors.TemplatePlacementError): diff --git a/tensorflow_federated/python/learning/templates/hparams_base.py b/tensorflow_federated/python/learning/templates/hparams_base.py index ea2ff8865c..33246c80e6 100644 --- a/tensorflow_federated/python/learning/templates/hparams_base.py +++ b/tensorflow_federated/python/learning/templates/hparams_base.py @@ -15,10 +15,10 @@ import collections +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types class GetHparamsTypeError(TypeError): @@ -30,11 +30,13 @@ class SetHparamsTypeError(TypeError): def type_check_get_hparams_fn( - get_hparams_fn: computation_base.Computation, - state_type: computation_types.Type, + get_hparams_fn: federated_language.framework.Computation, + state_type: federated_language.Type, ): """Validates the type signature of `get_hparams_fn` in `ClientWorkProcess`.""" - py_typecheck.check_type(get_hparams_fn, computation_base.Computation) + py_typecheck.check_type( + get_hparams_fn, federated_language.framework.Computation + ) get_hparams_state_type = get_hparams_fn.type_signature.parameter if ( get_hparams_state_type is None @@ -47,14 +49,16 @@ def type_check_get_hparams_fn( def type_check_set_hparams_fn( - set_hparams_fn: computation_base.Computation, - state_type: computation_types.Type, + set_hparams_fn: federated_language.framework.Computation, + state_type: federated_language.Type, ): """Validates the type signature of `set_hparams_fn` in `ClientWorkProcess`.""" - py_typecheck.check_type(set_hparams_fn, computation_base.Computation) + py_typecheck.check_type( + set_hparams_fn, federated_language.framework.Computation + ) set_hparams_parameter = set_hparams_fn.type_signature.parameter if ( - not isinstance(set_hparams_parameter, computation_types.StructType) + not isinstance(set_hparams_parameter, federated_language.StructType) or len(set_hparams_parameter) != 2 ): raise SetHparamsTypeError( @@ -76,8 +80,8 @@ def type_check_set_hparams_fn( def build_basic_hparams_getter( - state_type: computation_types.Type, -) -> computation_base.Computation: + state_type: federated_language.Type, +) -> federated_language.framework.Computation: """Creates a `tff.Computation` that returns an empty ordered dictionary.""" @tensorflow_computation.tf_computation(state_type) @@ -89,8 +93,8 @@ def get_hparams_computation(state): def build_basic_hparams_setter( - state_type: computation_types.Type, hparams_type: computation_types.Type -) -> computation_base.Computation: + state_type: federated_language.Type, hparams_type: federated_language.Type +) -> federated_language.framework.Computation: """Creates a `tff.Computation` that returns the state, unchanged.""" @tensorflow_computation.tf_computation(state_type, hparams_type) diff --git a/tensorflow_federated/python/learning/templates/hparams_base_test.py b/tensorflow_federated/python/learning/templates/hparams_base_test.py index b07ad0f2e4..4697aa5dfb 100644 --- a/tensorflow_federated/python/learning/templates/hparams_base_test.py +++ b/tensorflow_federated/python/learning/templates/hparams_base_test.py @@ -15,18 +15,17 @@ import collections from absl.testing import absltest +import federated_language import numpy as np from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.learning.templates import hparams_base class HparamsBaseTest(absltest.TestCase): def test_get_hparams_with_compatible_state_type_does_not_raise(self): - state_type = computation_types.TensorType(np.int32) + state_type = federated_language.TensorType(np.int32) @tensorflow_computation.tf_computation(np.int32) def get_hparams_fn(state): @@ -35,7 +34,7 @@ def get_hparams_fn(state): hparams_base.type_check_get_hparams_fn(get_hparams_fn, state_type) def test_get_hparams_with_incompatible_state_type(self): - state_type = computation_types.TensorType(np.int32) + state_type = federated_language.TensorType(np.int32) @tensorflow_computation.tf_computation(np.float32) def get_hparams_fn(state): @@ -45,7 +44,7 @@ def get_hparams_fn(state): hparams_base.type_check_get_hparams_fn(get_hparams_fn, state_type) def test_set_hparams_fn_with_one_input_arg_raises(self): - state_type = computation_types.TensorType(np.int32) + state_type = federated_language.TensorType(np.int32) @tensorflow_computation.tf_computation(np.int32) def set_hparams_fn(state): @@ -55,7 +54,7 @@ def set_hparams_fn(state): hparams_base.type_check_set_hparams_fn(set_hparams_fn, state_type) def test_set_hparams_fn_with_three_input_args_raises(self): - state_type = computation_types.TensorType(np.int32) + state_type = federated_language.TensorType(np.int32) @tensorflow_computation.tf_computation(np.int32, np.int32, np.int32) def set_hparams_fn(state, x, y): @@ -67,8 +66,8 @@ def set_hparams_fn(state, x, y): hparams_base.type_check_set_hparams_fn(set_hparams_fn, state_type) def test_set_hparams_fn_with_compatible_state_type_does_not_raise(self): - state_type = computation_types.TensorType(np.int32) - hparams_type = computation_types.to_type( + state_type = federated_language.TensorType(np.int32) + hparams_type = federated_language.to_type( collections.OrderedDict(a=np.int32) ) @@ -80,8 +79,8 @@ def set_hparams_fn(state, hparams): hparams_base.type_check_set_hparams_fn(set_hparams_fn, state_type) def test_set_hparams_fn_with_incompatible_input_state_type_raises(self): - state_type = computation_types.TensorType(np.int32) - hparams_type = computation_types.to_type( + state_type = federated_language.TensorType(np.int32) + hparams_type = federated_language.to_type( collections.OrderedDict(a=np.int32) ) @@ -94,8 +93,8 @@ def set_hparams_fn(state, hparams): hparams_base.type_check_set_hparams_fn(set_hparams_fn, state_type) def test_set_hparams_fn_with_incompatible_outputput_state_type_raises(self): - state_type = computation_types.TensorType(np.int32) - hparams_type = computation_types.to_type( + state_type = federated_language.TensorType(np.int32) + hparams_type = federated_language.to_type( collections.OrderedDict(a=np.float32) ) @@ -108,31 +107,33 @@ def set_hparams_fn(state, hparams): hparams_base.type_check_set_hparams_fn(set_hparams_fn, state_type) def test_default_get_hparams_returns_empty_dict(self): - state_type = computation_types.TensorType(np.int32) + state_type = federated_language.TensorType(np.int32) get_hparams_fn = hparams_base.build_basic_hparams_getter(state_type) - expected_hparams_type = computation_types.to_type(collections.OrderedDict()) - expected_function_type = computation_types.FunctionType( + expected_hparams_type = federated_language.to_type( + collections.OrderedDict() + ) + expected_function_type = federated_language.FunctionType( parameter=state_type, result=expected_hparams_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( get_hparams_fn.type_signature, expected_function_type ) def test_default_set_hparams_returns_state_of_matching_type(self): - state_type = computation_types.TensorType(np.int32) - hparams_type = computation_types.to_type( + state_type = federated_language.TensorType(np.int32) + hparams_type = federated_language.to_type( collections.OrderedDict(a=np.float32) ) set_hparams_fn = hparams_base.build_basic_hparams_setter( state_type, hparams_type ) - expected_function_type = computation_types.FunctionType( - parameter=computation_types.StructType( + expected_function_type = federated_language.FunctionType( + parameter=federated_language.StructType( [('state', state_type), ('hparams', hparams_type)] ), result=state_type, ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( set_hparams_fn.type_signature, expected_function_type ) diff --git a/tensorflow_federated/python/learning/templates/learning_process.py b/tensorflow_federated/python/learning/templates/learning_process.py index 481fd6ead3..329654d55c 100644 --- a/tensorflow_federated/python/learning/templates/learning_process.py +++ b/tensorflow_federated/python/learning/templates/learning_process.py @@ -16,10 +16,9 @@ import typing from typing import Any, NamedTuple, Optional, Union +import federated_language + from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.learning.templates import hparams_base @@ -91,13 +90,13 @@ class LearningProcess(iterative_process.IterativeProcess): def __init__( self, - initialize_fn: computation_base.Computation, - next_fn: computation_base.Computation, - get_model_weights: computation_base.Computation, - set_model_weights: computation_base.Computation, + initialize_fn: federated_language.framework.Computation, + next_fn: federated_language.framework.Computation, + get_model_weights: federated_language.framework.Computation, + set_model_weights: federated_language.framework.Computation, *, - get_hparams_fn: Optional[computation_base.Computation] = None, - set_hparams_fn: Optional[computation_base.Computation] = None, + get_hparams_fn: Optional[federated_language.framework.Computation] = None, + set_hparams_fn: Optional[federated_language.framework.Computation] = None, ): """Creates a `tff.learning.templates.LearningProcess`. @@ -178,7 +177,7 @@ def __init__( super().__init__(initialize_fn, next_fn) init_fn_result = initialize_fn.type_signature.result - if init_fn_result.placement != placements.SERVER: # pytype: disable=attribute-error + if init_fn_result.placement != federated_language.SERVER: # pytype: disable=attribute-error raise LearningProcessPlacementError( 'The result of `initialize_fn` must be placed at `SERVER` but found ' f'placement {init_fn_result.placement}.' # pytype: disable=attribute-error @@ -187,10 +186,10 @@ def __init__( next_result_type = next_fn.type_signature.result # TODO: b/224484886 - Downcasting to all handled types. next_result_type = typing.cast( - Union[computation_types.StructWithPythonType], next_result_type + Union[federated_language.StructWithPythonType], next_result_type ) if not ( - isinstance(next_result_type, computation_types.StructWithPythonType) + isinstance(next_result_type, federated_language.StructWithPythonType) and next_result_type.python_container is LearningProcessOutput ): raise LearningProcessOutputError( @@ -201,30 +200,32 @@ def __init__( # base class. # TODO: b/224484886 - Downcasting to all handled types. next_fn_param = typing.cast( - Union[computation_types.StructType], next_fn.type_signature.parameter + Union[federated_language.StructType], next_fn.type_signature.parameter ) if ( - not isinstance(next_fn_param, computation_types.StructType) + not isinstance(next_fn_param, federated_language.StructType) or len(next_fn_param) != 2 ): raise errors.TemplateNextFnNumArgsError( 'The `next_fn` must have two input arguments, but found an input ' f'of type {next_fn_param}.' ) - if next_fn_param[1].placement != placements.CLIENTS: + if next_fn_param[1].placement != federated_language.CLIENTS: raise LearningProcessPlacementError( 'The second input argument of `next_fn` must be placed at `CLIENTS`,' f' but found placement {next_fn_param[1].placement}.' ) next_fn_result = next_fn.type_signature.result - if next_fn_result.metrics.placement != placements.SERVER: # pytype: disable=attribute-error + if next_fn_result.metrics.placement != federated_language.SERVER: # pytype: disable=attribute-error raise LearningProcessPlacementError( 'The result of `next_fn` must be placed at `SERVER` but found ' f'placement {next_fn_result.metrics.placement} for `metrics`.' # pytype: disable=attribute-error ) - py_typecheck.check_type(get_model_weights, computation_base.Computation) + py_typecheck.check_type( + get_model_weights, federated_language.framework.Computation + ) get_model_weights_type = get_model_weights.type_signature get_model_weights_param = get_model_weights_type.parameter next_fn_state_param = next_fn.type_signature.parameter[0].member # pytype: disable=unsupported-operands @@ -240,7 +241,9 @@ def __init__( ) self._get_model_weights = get_model_weights - py_typecheck.check_type(set_model_weights, computation_base.Computation) + py_typecheck.check_type( + set_model_weights, federated_language.framework.Computation + ) set_model_weights_type = set_model_weights.type_signature set_model_weights_state_param = set_model_weights_type.parameter[0] # pytype: disable=unsupported-operands if not set_model_weights_state_param.is_equivalent_to(next_fn_state_param): @@ -279,7 +282,7 @@ def __init__( self._set_hparams_fn = set_hparams_fn @property - def initialize(self) -> computation_base.Computation: + def initialize(self) -> federated_language.framework.Computation: """A `tff.Computation` that initializes the process. This computation must have no input arguments, and its output must be the @@ -291,7 +294,7 @@ def initialize(self) -> computation_base.Computation: return super().initialize @property - def next(self) -> computation_base.Computation: + def next(self) -> federated_language.framework.Computation: """A `tff.Computation` that runs one iteration of the process. The first argument of this computation should always be the current state @@ -305,7 +308,7 @@ def next(self) -> computation_base.Computation: return super().next @property - def get_model_weights(self) -> computation_base.Computation: + def get_model_weights(self) -> federated_language.framework.Computation: """A `tff.Computation` returning the model weights of a server state. This computation accepts an unplaced state of the process (originally @@ -321,7 +324,7 @@ def get_model_weights(self) -> computation_base.Computation: return self._get_model_weights @property - def set_model_weights(self) -> computation_base.Computation: + def set_model_weights(self) -> federated_language.framework.Computation: """A `tff.Computation` that sets the model weights of a server state. This computation accepts two arguments: an unplaced state of the process @@ -338,7 +341,7 @@ def set_model_weights(self) -> computation_base.Computation: return self._set_model_weights @property - def get_hparams(self) -> computation_base.Computation: + def get_hparams(self) -> federated_language.framework.Computation: """A `tff.Computation` returning the hyperparameters of a server state. This computation accepts an unplaced state of the process (originally @@ -351,7 +354,7 @@ def get_hparams(self) -> computation_base.Computation: return self._get_hparams_fn @property - def set_hparams(self) -> computation_base.Computation: + def set_hparams(self) -> federated_language.framework.Computation: """A `tff.Computation` that sets the hyperparamters of a server state. This computation accepts two arguments: an unplaced state of the process diff --git a/tensorflow_federated/python/learning/templates/learning_process_test.py b/tensorflow_federated/python/learning/templates/learning_process_test.py index 1748c24848..2adbd15d93 100644 --- a/tensorflow_federated/python/learning/templates/learning_process_test.py +++ b/tensorflow_federated/python/learning/templates/learning_process_test.py @@ -15,22 +15,19 @@ import collections from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.learning.templates import learning_process # Convenience aliases. LearningProcessOutput = learning_process.LearningProcessOutput -SequenceType = computation_types.SequenceType -TensorType = computation_types.TensorType -federated_computation = federated_computation.federated_computation +SequenceType = federated_language.SequenceType +TensorType = federated_language.TensorType +federated_computation = federated_language.federated_computation tf_computation = tensorflow_computation.tf_computation _LearningProcessConstructionError = ( @@ -64,7 +61,7 @@ def take_arg_set_model_weights(state, model_weights): @federated_computation def test_init_fn(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) @tf_computation(SequenceType(np.int32)) @@ -79,21 +76,24 @@ def sum_dataset(dataset): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(SequenceType(np.int32), placements.CLIENTS), + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS + ), ) def test_next_fn(state, data): - client_sums = intrinsics.federated_map(sum_dataset, data) - server_sum = intrinsics.federated_sum(client_sums) + client_sums = federated_language.federated_map(sum_dataset, data) + server_sum = federated_language.federated_sum(client_sums) @tf_computation def add(x, y): """Function to hide `tf.add`'s `name` parameter from TFF.""" return tf.add(x, y) - result = intrinsics.federated_map(add, (state, server_sum)) + result = federated_language.federated_map(add, (state, server_sum)) return LearningProcessOutput( - state=result, metrics=intrinsics.federated_value((), placements.SERVER) + state=result, + metrics=federated_language.federated_value((), federated_language.SERVER), ) @@ -138,19 +138,25 @@ def test_construction_with_empty_state_does_not_raise(self): @federated_computation def empty_initialize_fn(): - return intrinsics.federated_value(empty_tuple, placements.SERVER) + return federated_language.federated_value( + empty_tuple, federated_language.SERVER + ) @federated_computation( - computation_types.FederatedType(empty_tuple, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType( + empty_tuple, federated_language.SERVER + ), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), ) def next_fn(state, value): del value # Unused. return LearningProcessOutput( state=state, - metrics=intrinsics.federated_value(empty_tuple, placements.SERVER), + metrics=federated_language.federated_value( + empty_tuple, federated_language.SERVER + ), ) try: @@ -166,8 +172,9 @@ def next_fn(state, value): def test_construction_with_unknown_dimension_does_not_raise(self): @federated_computation def initialize_fn(): - return intrinsics.federated_eval( - tf_computation(lambda: tf.constant([], tf.string)), placements.SERVER + return federated_language.federated_eval( + tf_computation(lambda: tf.constant([], tf.string)), + federated_language.SERVER, ) # This replicates a tensor that can grow in string length. The @@ -176,17 +183,18 @@ def initialize_fn(): none_dimension_string_type = TensorType(np.str_, [None]) @federated_computation( - computation_types.FederatedType( - none_dimension_string_type, placements.SERVER + federated_language.FederatedType( + none_dimension_string_type, federated_language.SERVER ), - computation_types.FederatedType( - SequenceType(np.str_), placements.CLIENTS + federated_language.FederatedType( + SequenceType(np.str_), federated_language.CLIENTS ), ) def next_fn(state, datasets): del datasets # Unused. return LearningProcessOutput( - state, intrinsics.federated_value((), placements.SERVER) + state, + federated_language.federated_value((), federated_language.SERVER), ) try: @@ -207,9 +215,9 @@ def next_fn(state, datasets): def test_construction_with_nested_datasets_does_not_raise(self): @federated_computation def initialize_fn(): - return intrinsics.federated_eval( + return federated_language.federated_eval( tf_computation(lambda: tf.constant(0.0, tf.float32)), - placements.SERVER, + federated_language.SERVER, ) # Test that clients can receive multiple datasets. @@ -219,13 +227,16 @@ def initialize_fn(): ) @federated_computation( - computation_types.FederatedType(np.float32, placements.SERVER), - computation_types.FederatedType(datasets_type, placements.CLIENTS), + federated_language.FederatedType(np.float32, federated_language.SERVER), + federated_language.FederatedType( + datasets_type, federated_language.CLIENTS + ), ) def next_fn(state, datasets): del datasets # Unused. return LearningProcessOutput( - state, intrinsics.federated_value((), placements.SERVER) + state, + federated_language.federated_value((), federated_language.SERVER), ) try: @@ -263,7 +274,7 @@ def test_next_not_tff_computation_raises(self): def test_init_param_not_empty_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def one_arg_initialize_fn(x): return x @@ -303,12 +314,12 @@ def test_next_state_not_federated(self): def test_init_fn_with_client_placed_state_raises(self): @federated_computation def init_fn(): - return intrinsics.federated_value(0, placements.CLIENTS) + return federated_language.federated_value(0, federated_language.CLIENTS) @federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.CLIENTS), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), ) def next_fn(state, client_values): @@ -322,14 +333,14 @@ def next_fn(state, client_values): def test_next_return_tuple_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), ) def tuple_next_fn(state, client_values): - metrics = intrinsics.federated_map(sum_dataset, client_values) - metrics = intrinsics.federated_sum(metrics) + metrics = federated_language.federated_map(sum_dataset, client_values) + metrics = federated_language.federated_sum(metrics) return (state, metrics) with self.assertRaises(learning_process.LearningProcessOutputError): @@ -346,14 +357,14 @@ def test_next_return_namedtuple_raises(self): ) @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), ) def namedtuple_next_fn(state, client_values): - metrics = intrinsics.federated_map(sum_dataset, client_values) - metrics = intrinsics.federated_sum(metrics) + metrics = federated_language.federated_map(sum_dataset, client_values) + metrics = federated_language.federated_sum(metrics) return learning_process_output(state, metrics) with self.assertRaises(learning_process.LearningProcessOutputError): @@ -367,14 +378,14 @@ def namedtuple_next_fn(state, client_values): def test_next_return_odict_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), ) def odict_next_fn(state, client_values): - metrics = intrinsics.federated_map(sum_dataset, client_values) - metrics = intrinsics.federated_sum(metrics) + metrics = federated_language.federated_map(sum_dataset, client_values) + metrics = federated_language.federated_sum(metrics) return collections.OrderedDict(state=state, metrics=metrics) with self.assertRaises(learning_process.LearningProcessOutputError): @@ -388,7 +399,7 @@ def odict_next_fn(state, client_values): def test_next_fn_with_one_parameter_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + federated_language.FederatedType(np.int32, federated_language.SERVER) ) def next_fn(state): return LearningProcessOutput(state, 0) @@ -404,16 +415,16 @@ def next_fn(state): def test_next_fn_with_three_parameters_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), ) def next_fn(state, client_values, second_state): del second_state # Unused. - metrics = intrinsics.federated_map(sum_dataset, client_values) - metrics = intrinsics.federated_sum(metrics) + metrics = federated_language.federated_map(sum_dataset, client_values) + metrics = federated_language.federated_sum(metrics) return LearningProcessOutput(state, metrics) with self.assertRaises(errors.TemplateNextFnNumArgsError): @@ -427,13 +438,13 @@ def next_fn(state, client_values, second_state): def test_next_fn_with_server_placed_second_arg_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.SERVER + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.SERVER ), ) def next_fn(state, server_values): - metrics = intrinsics.federated_map(sum_dataset, server_values) + metrics = federated_language.federated_map(sum_dataset, server_values) return LearningProcessOutput(state, metrics) with self.assertRaises(learning_process.LearningProcessPlacementError): @@ -447,9 +458,9 @@ def next_fn(state, server_values): def test_next_fn_with_client_placed_metrics_result_raises(self): @federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType( - SequenceType(np.int32), placements.CLIENTS + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType( + SequenceType(np.int32), federated_language.CLIENTS ), ) def next_fn(state, metrics): @@ -474,8 +485,8 @@ def test_non_tff_computation_get_model_weights_raises(self): ) def test_non_functional_get_model_weights_raises(self): - get_model_weights = computation_types.FederatedType( - np.int32, placements.SERVER + get_model_weights = federated_language.FederatedType( + np.int32, federated_language.SERVER ) with self.assertRaises(TypeError): learning_process.LearningProcess( @@ -487,7 +498,7 @@ def test_non_functional_get_model_weights_raises(self): def test_federated_get_model_weights_raises(self): bad_get_model_weights = create_pass_through_get_model_weights( - computation_types.FederatedType(np.float32, placements.SERVER) + federated_language.FederatedType(np.float32, federated_language.SERVER) ) with self.assertRaises(learning_process.GetModelWeightsTypeSignatureError): learning_process.LearningProcess( diff --git a/tensorflow_federated/python/learning/templates/model_delta_client_work.py b/tensorflow_federated/python/learning/templates/model_delta_client_work.py index 048d295f4e..1217976d33 100644 --- a/tensorflow_federated/python/learning/templates/model_delta_client_work.py +++ b/tensorflow_federated/python/learning/templates/model_delta_client_work.py @@ -23,15 +23,12 @@ from collections.abc import Callable, Mapping from typing import Any, Optional +import federated_language import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder @@ -203,7 +200,7 @@ def build_model_delta_client_work( model = model_fn() metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers()) element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) weights_type = model_weights_lib.weights_type_from_model(model) # We initialize the optimizer for the purposes of extracting its @@ -212,9 +209,11 @@ def build_model_delta_client_work( whimsy_opt_state = optimizer.initialize(whimsy_specs) initial_hparams = optimizer.get_hparams(whimsy_opt_state) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value(initial_hparams, placements.SERVER) + return federated_language.federated_value( + initial_hparams, federated_language.SERVER + ) state_type = init_fn.type_signature.result.member # In this case, the state is exactly equal to the hyperparameters being @@ -242,18 +241,20 @@ def client_update_computation(state, initial_model_weights, dataset): optimizer, initial_model_weights, dataset, optimizer_hparams ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): - state_at_clients = intrinsics.federated_broadcast(state) - client_result, model_outputs = intrinsics.federated_map( + state_at_clients = federated_language.federated_broadcast(state) + client_result, model_outputs = federated_language.federated_map( client_update_computation, (state_at_clients, weights, client_data) ) train_metrics = metrics_aggregation_fn(model_outputs) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( @@ -418,7 +419,7 @@ def build_functional_model_delta_client_work( py_typecheck.check_type(optimizer, optimizer_base.Optimizer) py_typecheck.check_type(client_weighting, client_weight_lib.ClientWeighting) element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) def ndarray_to_tensorspec(ndarray): return tf.TensorSpec( @@ -446,26 +447,28 @@ def client_update_computation(initial_model_weights, dataset): ) return client_update(optimizer, initial_model_weights, dataset) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): # Empty tuple means "no state" / stateless. - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) if metrics_aggregator is None: metrics_aggregator = aggregator.sum_then_finalize - @federated_computation.federated_computation( - computation_types.FederatedType((), placements.SERVER), - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + @federated_language.federated_computation( + federated_language.FederatedType((), federated_language.SERVER), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): - client_result, unfinalized_metrics = intrinsics.federated_map( + client_result, unfinalized_metrics = federated_language.federated_map( client_update_computation, (weights, client_data) ) metrics_aggregation_fn = metrics_aggregator(model.finalize_metrics) finalized_training_metrics = metrics_aggregation_fn(unfinalized_metrics) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=finalized_training_metrics) ) return measured_process.MeasuredProcessOutput( diff --git a/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py b/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py index 3c3f9bbe22..b7a9014b2c 100644 --- a/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py +++ b/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py @@ -17,17 +17,13 @@ from unittest import mock from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder @@ -59,13 +55,14 @@ def test_initialize_has_expected_type_signature_with_tff_optimizer( model_fn, optimizer, weighting ) - expected_state_type = computation_types.FederatedType( - collections.OrderedDict(learning_rate=np.float32), placements.SERVER + expected_state_type = federated_language.FederatedType( + collections.OrderedDict(learning_rate=np.float32), + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( client_work_process.initialize.type_signature, expected_initialize_type ) @@ -82,35 +79,39 @@ def test_next_has_expected_type_signature_with_tff_optimizer(self, weighting): ) mw_type = model_weights_lib.ModelWeights( - trainable=computation_types.to_type([(np.float32, (2, 1)), np.float32]), - non_trainable=computation_types.to_type([np.float32]), + trainable=federated_language.to_type( + [(np.float32, (2, 1)), np.float32] + ), + non_trainable=federated_language.to_type([np.float32]), ) - expected_param_model_weights_type = computation_types.FederatedType( - mw_type, placements.CLIENTS + expected_param_model_weights_type = federated_language.FederatedType( + mw_type, federated_language.CLIENTS ) element_type = tensorflow_types.to_type(model_fn().input_spec) - expected_param_data_type = computation_types.FederatedType( - computation_types.SequenceType(element_type), placements.CLIENTS + expected_param_data_type = federated_language.FederatedType( + federated_language.SequenceType(element_type), + federated_language.CLIENTS, ) - expected_result_type = computation_types.FederatedType( + expected_result_type = federated_language.FederatedType( client_works.ClientResult( update=mw_type.trainable, - update_weight=computation_types.TensorType(np.float32), + update_weight=federated_language.TensorType(np.float32), ), - placements.CLIENTS, + federated_language.CLIENTS, ) - expected_state_type = computation_types.FederatedType( - collections.OrderedDict(learning_rate=np.float32), placements.SERVER + expected_state_type = federated_language.FederatedType( + collections.OrderedDict(learning_rate=np.float32), + federated_language.SERVER, ) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( train=collections.OrderedDict( loss=np.float32, num_examples=np.int32 ) ), - placements.SERVER, + federated_language.SERVER, ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, weights=expected_param_model_weights_type, @@ -122,7 +123,7 @@ def test_next_has_expected_type_signature_with_tff_optimizer(self, weighting): expected_measurements_type, ), ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( client_work_process.next.type_signature, expected_next_type ) @@ -142,10 +143,10 @@ def test_get_hparams_has_expected_type_signature_with_tff_optimizer( expected_state_type = collections.OrderedDict(learning_rate=np.float32) expected_hparams_type = expected_state_type - expected_get_hparams_type = computation_types.FunctionType( + expected_get_hparams_type = federated_language.FunctionType( parameter=expected_state_type, result=expected_hparams_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( client_work_process.get_hparams.type_signature, expected_get_hparams_type, ) @@ -166,13 +167,13 @@ def test_set_hparams_has_expected_type_signature_with_tff_optimizer( expected_state_type = collections.OrderedDict(learning_rate=np.float32) expected_hparams_type = expected_state_type - expected_parameter_type = computation_types.StructType( + expected_parameter_type = federated_language.StructType( [('state', expected_state_type), ('hparams', expected_hparams_type)] ) - expected_set_hparams_type = computation_types.FunctionType( + expected_set_hparams_type = federated_language.FunctionType( parameter=expected_parameter_type, result=expected_state_type ) - type_test_utils.assert_types_equivalent( + federated_language.framework.assert_types_equivalent( client_work_process.set_hparams.type_signature, expected_set_hparams_type, ) @@ -319,7 +320,7 @@ def test_correct_update_weight_with_traced_function(self): y=[[0.0], [0.0], [1.0], [1.0]], ) ).batch(1) - # Obtain a concrete function after tracing. + # Obtain a concrete function after federated_language.framework. client_concrete_fn = client_tf.get_concrete_function( optimizer, init_weights, dataset_wo_nan, optimizer_hparams=None ) @@ -339,9 +340,10 @@ def test_correct_update_weight_with_traced_function(self): def test_custom_metrics_aggregator(self): def sum_then_finalize_then_times_two(metric_finalizers): - @federated_computation.federated_computation + + @federated_language.federated_computation def aggregation_computation(client_local_unfinalized_metrics): - unfinalized_metrics_sum = intrinsics.federated_sum( + unfinalized_metrics_sum = federated_language.federated_sum( client_local_unfinalized_metrics ) @@ -354,7 +356,7 @@ def finalizer_computation(unfinalized_metrics): ) return finalized_metrics - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, unfinalized_metrics_sum ) diff --git a/tensorflow_federated/python/learning/templates/proximal_client_work.py b/tensorflow_federated/python/learning/templates/proximal_client_work.py index 201fd33aa2..f568d4621c 100644 --- a/tensorflow_federated/python/learning/templates/proximal_client_work.py +++ b/tensorflow_federated/python/learning/templates/proximal_client_work.py @@ -24,15 +24,12 @@ from typing import Any, Optional from absl import logging +import federated_language import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder @@ -345,7 +342,7 @@ def build_model_delta_client_work( model.metric_finalizers(), ) element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) weights_type = model_weights_lib.weights_type_from_model(model) @tensorflow_computation.tf_computation(weights_type, data_type) @@ -358,21 +355,23 @@ def client_update_computation(initial_model_weights, dataset): ) return client_update(optimizer, initial_model_weights, dataset) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): - return intrinsics.federated_value((), placements.SERVER) + return federated_language.federated_value((), federated_language.SERVER) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): - client_result, model_outputs = intrinsics.federated_map( + client_result, model_outputs = federated_language.federated_map( client_update_computation, (weights, client_data) ) train_metrics = metrics_aggregation_fn(model_outputs) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( @@ -429,7 +428,7 @@ def build_functional_model_delta_client_work( if metrics_aggregator is None: metrics_aggregator = aggregator.sum_then_finalize element_type = tensorflow_types.to_type(model.input_spec) - data_type = computation_types.SequenceType(element_type) + data_type = federated_language.SequenceType(element_type) def ndarray_to_tensorspec(ndarray): return tf.TensorSpec(shape=ndarray.shape, dtype=ndarray.dtype) @@ -456,25 +455,29 @@ def client_update_computation(initial_model_weights, dataset): ) return client_update(optimizer, initial_model_weights, dataset) - @federated_computation.federated_computation + @federated_language.federated_computation def init_fn(): empty_state = () - return intrinsics.federated_value(empty_state, placements.SERVER) + return federated_language.federated_value( + empty_state, federated_language.SERVER + ) - @federated_computation.federated_computation( + @federated_language.federated_computation( init_fn.type_signature.result, - computation_types.FederatedType(weights_type, placements.CLIENTS), - computation_types.FederatedType(data_type, placements.CLIENTS), + federated_language.FederatedType( + weights_type, federated_language.CLIENTS + ), + federated_language.FederatedType(data_type, federated_language.CLIENTS), ) def next_fn(state, weights, client_data): - client_result, unfinalized_metrics = intrinsics.federated_map( + client_result, unfinalized_metrics = federated_language.federated_map( client_update_computation, (weights, client_data) ) metrics_aggregation_fn = metrics_aggregator( model.finalize_metrics, unfinalized_metrics.type_signature.member ) train_metrics = metrics_aggregation_fn(unfinalized_metrics) - measurements = intrinsics.federated_zip( + measurements = federated_language.federated_zip( collections.OrderedDict(train=train_metrics) ) return measured_process.MeasuredProcessOutput( diff --git a/tensorflow_federated/python/learning/templates/proximal_client_work_test.py b/tensorflow_federated/python/learning/templates/proximal_client_work_test.py index 798774e32a..8ba6db15cc 100644 --- a/tensorflow_federated/python/learning/templates/proximal_client_work_test.py +++ b/tensorflow_federated/python/learning/templates/proximal_client_work_test.py @@ -16,16 +16,13 @@ from unittest import mock from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder @@ -66,41 +63,46 @@ def test_type_properties(self, optimizer, weighting): self.assertIsInstance(client_work_process, client_works.ClientWorkProcess) mw_type = model_weights.ModelWeights( - trainable=computation_types.to_type([(np.float32, (2, 1)), np.float32]), - non_trainable=computation_types.to_type([np.float32]), + trainable=federated_language.to_type( + [(np.float32, (2, 1)), np.float32] + ), + non_trainable=federated_language.to_type([np.float32]), ) - expected_param_model_weights_type = computation_types.FederatedType( - mw_type, placements.CLIENTS + expected_param_model_weights_type = federated_language.FederatedType( + mw_type, federated_language.CLIENTS ) element_type = tensorflow_types.to_type(model_fn().input_spec) - expected_param_data_type = computation_types.FederatedType( - computation_types.SequenceType(element_type), placements.CLIENTS + expected_param_data_type = federated_language.FederatedType( + federated_language.SequenceType(element_type), + federated_language.CLIENTS, ) - expected_result_type = computation_types.FederatedType( + expected_result_type = federated_language.FederatedType( client_works.ClientResult( update=mw_type.trainable, - update_weight=computation_types.TensorType(np.float32), + update_weight=federated_language.TensorType(np.float32), ), - placements.CLIENTS, + federated_language.CLIENTS, + ) + expected_state_type = federated_language.FederatedType( + (), federated_language.SERVER ) - expected_state_type = computation_types.FederatedType((), placements.SERVER) - expected_measurements_type = computation_types.FederatedType( + expected_measurements_type = federated_language.FederatedType( collections.OrderedDict( train=collections.OrderedDict( loss=np.float32, num_examples=np.int32 ) ), - placements.SERVER, + federated_language.SERVER, ) - expected_initialize_type = computation_types.FunctionType( + expected_initialize_type = federated_language.FunctionType( parameter=None, result=expected_state_type ) expected_initialize_type.check_equivalent_to( client_work_process.initialize.type_signature ) - expected_next_type = computation_types.FunctionType( + expected_next_type = federated_language.FunctionType( parameter=collections.OrderedDict( state=expected_state_type, weights=expected_param_model_weights_type, @@ -239,9 +241,10 @@ def test_non_finite_aggregation(self, bad_value): def test_custom_metrics_aggregator(self): def sum_then_finalize_then_times_two(metric_finalizers): - @federated_computation.federated_computation + + @federated_language.federated_computation def aggregation_computation(client_local_unfinalized_metrics): - unfinalized_metrics_sum = intrinsics.federated_sum( + unfinalized_metrics_sum = federated_language.federated_sum( client_local_unfinalized_metrics ) @@ -254,7 +257,7 @@ def finalizer_computation(unfinalized_metrics): ) return finalized_metrics - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, unfinalized_metrics_sum ) @@ -433,13 +436,13 @@ def sum_then_finalize_then_times_two( metric_finalizers, local_unfinalized_metrics_type ): - @federated_computation.federated_computation( - computation_types.FederatedType( - local_unfinalized_metrics_type, placements.CLIENTS + @federated_language.federated_computation( + federated_language.FederatedType( + local_unfinalized_metrics_type, federated_language.CLIENTS ) ) def aggregation_computation(client_local_unfinalized_metrics): - unfinalized_metrics_sum = intrinsics.federated_sum( + unfinalized_metrics_sum = federated_language.federated_sum( client_local_unfinalized_metrics ) @@ -449,7 +452,7 @@ def finalizer_computation(unfinalized_metrics): lambda x: x * 2, metric_finalizers(unfinalized_metrics) ) - return intrinsics.federated_map( + return federated_language.federated_map( finalizer_computation, unfinalized_metrics_sum ) diff --git a/tensorflow_federated/python/learning/templates/type_checks.py b/tensorflow_federated/python/learning/templates/type_checks.py index 3f32501949..5d1b645971 100644 --- a/tensorflow_federated/python/learning/templates/type_checks.py +++ b/tensorflow_federated/python/learning/templates/type_checks.py @@ -15,9 +15,7 @@ from typing import Optional -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis +import federated_language class ClientSequenceTypeError(Exception): @@ -25,7 +23,7 @@ class ClientSequenceTypeError(Exception): def check_is_client_placed_structure_of_sequences( - type_spec: computation_types.Type, error_message: Optional[str] = None + type_spec: federated_language.Type, error_message: Optional[str] = None ) -> None: """Checks that a type is a structure of sequences, placed at `tff.CLIENTS`. @@ -39,10 +37,12 @@ def check_is_client_placed_structure_of_sequences( if its member type is not a structure of TensorFlow-compatible sequences. """ - def is_structure_of_sequences(member_spec: computation_types.Type) -> bool: - if isinstance(member_spec, computation_types.SequenceType): - return type_analysis.is_tensorflow_compatible_type(member_spec.element) - elif isinstance(member_spec, computation_types.StructType): + def is_structure_of_sequences(member_spec: federated_language.Type) -> bool: + if isinstance(member_spec, federated_language.SequenceType): + return federated_language.framework.is_tensorflow_compatible_type( + member_spec.element + ) + elif isinstance(member_spec, federated_language.StructType): return all( is_structure_of_sequences(element_type) for element_type in member_spec.children() @@ -58,8 +58,8 @@ def is_structure_of_sequences(member_spec: computation_types.Type) -> bool: ) if ( - not isinstance(type_spec, computation_types.FederatedType) - or type_spec.placement is not placements.CLIENTS + not isinstance(type_spec, federated_language.FederatedType) + or type_spec.placement is not federated_language.CLIENTS or not is_structure_of_sequences(type_spec.member) ): raise ClientSequenceTypeError(error_message) diff --git a/tensorflow_federated/python/learning/templates/type_checks_test.py b/tensorflow_federated/python/learning/templates/type_checks_test.py index 492d54ce74..161549e6c4 100644 --- a/tensorflow_federated/python/learning/templates/type_checks_test.py +++ b/tensorflow_federated/python/learning/templates/type_checks_test.py @@ -12,71 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. +import federated_language import numpy as np import tensorflow as tf - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.learning.templates import type_checks class TypeChecksTest(tf.test.TestCase): def test_does_not_raise_on_client_placed_sequence(self): - sequence_type = computation_types.SequenceType(np.int32) - type_spec = computation_types.FederatedType( - sequence_type, placements.CLIENTS + sequence_type = federated_language.SequenceType(np.int32) + type_spec = federated_language.FederatedType( + sequence_type, federated_language.CLIENTS ) type_checks.check_is_client_placed_structure_of_sequences(type_spec) def test_does_not_raise_on_client_placed_struct_of_sequences(self): - sequence_type1 = computation_types.SequenceType(np.int32) - sequence_type2 = computation_types.SequenceType(np.float32) - struct_type = computation_types.StructWithPythonType( + sequence_type1 = federated_language.SequenceType(np.int32) + sequence_type2 = federated_language.SequenceType(np.float32) + struct_type = federated_language.StructWithPythonType( [sequence_type1, sequence_type2], list ) - type_spec = computation_types.FederatedType(struct_type, placements.CLIENTS) + type_spec = federated_language.FederatedType( + struct_type, federated_language.CLIENTS + ) type_checks.check_is_client_placed_structure_of_sequences(type_spec) def test_raises_on_server_placed_sequence(self): - sequence_type = computation_types.SequenceType(np.int32) - type_spec = computation_types.FederatedType( - sequence_type, placements.SERVER + sequence_type = federated_language.SequenceType(np.int32) + type_spec = federated_language.FederatedType( + sequence_type, federated_language.SERVER ) with self.assertRaises(type_checks.ClientSequenceTypeError): type_checks.check_is_client_placed_structure_of_sequences(type_spec) def test_raises_on_server_placed_struct_of_sequences(self): - sequence_type1 = computation_types.SequenceType(np.int32) - sequence_type2 = computation_types.SequenceType(np.float32) - struct_type = computation_types.StructWithPythonType( + sequence_type1 = federated_language.SequenceType(np.int32) + sequence_type2 = federated_language.SequenceType(np.float32) + struct_type = federated_language.StructWithPythonType( [sequence_type1, sequence_type2], list ) - type_spec = computation_types.FederatedType(struct_type, placements.SERVER) + type_spec = federated_language.FederatedType( + struct_type, federated_language.SERVER + ) with self.assertRaises(type_checks.ClientSequenceTypeError): type_checks.check_is_client_placed_structure_of_sequences(type_spec) def test_raises_on_client_placed_tensor(self): - tensor_spec = computation_types.TensorType(np.int32, (1, 2)) - type_spec = computation_types.FederatedType(tensor_spec, placements.CLIENTS) + tensor_spec = federated_language.TensorType(np.int32, (1, 2)) + type_spec = federated_language.FederatedType( + tensor_spec, federated_language.CLIENTS + ) with self.assertRaises(type_checks.ClientSequenceTypeError): type_checks.check_is_client_placed_structure_of_sequences(type_spec) def test_raises_on_client_placed_structure_of_tensor_and_sequence(self): - tensor_spec = computation_types.TensorType(np.int32, (1, 2)) - sequence_type = computation_types.SequenceType(np.int32) - struct_type = computation_types.StructWithPythonType( + tensor_spec = federated_language.TensorType(np.int32, (1, 2)) + sequence_type = federated_language.SequenceType(np.int32) + struct_type = federated_language.StructWithPythonType( [tensor_spec, sequence_type], list ) - type_spec = computation_types.FederatedType(struct_type, placements.CLIENTS) + type_spec = federated_language.FederatedType( + struct_type, federated_language.CLIENTS + ) with self.assertRaises(type_checks.ClientSequenceTypeError): type_checks.check_is_client_placed_structure_of_sequences(type_spec) def test_raises_on_structure_of_client_placed_sequences(self): - clients_sequence_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + clients_sequence_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) - type_spec = computation_types.StructType([ + type_spec = federated_language.StructType([ (None, clients_sequence_type), (None, clients_sequence_type), ]) diff --git a/tensorflow_federated/python/program/BUILD b/tensorflow_federated/python/program/BUILD index 5d28425f72..150eb07db5 100644 --- a/tensorflow_federated/python/program/BUILD +++ b/tensorflow_federated/python/program/BUILD @@ -23,18 +23,12 @@ py_library( visibility = ["//tensorflow_federated:__pkg__"], deps = [ ":client_id_data_source", - ":data_source", ":dataset_data_source", - ":federated_context", ":file_program_state_manager", ":file_release_manager", - ":logging_release_manager", - ":memory_release_manager", ":native_platform", - ":program_state_manager", - ":release_manager", ":tensorboard_release_manager", - ":value_reference", + "@federated_language//federated_language", ], ) @@ -42,10 +36,8 @@ py_library( name = "client_id_data_source", srcs = ["client_id_data_source.py"], deps = [ - ":data_source", ":serialization_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -54,17 +46,7 @@ py_test( srcs = ["client_id_data_source_test.py"], deps = [ ":client_id_data_source", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "data_source", - srcs = ["data_source.py"], - deps = [ - "//tensorflow_federated/python/common_libs:serializable", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -72,11 +54,9 @@ py_library( name = "dataset_data_source", srcs = ["dataset_data_source.py"], deps = [ - ":data_source", ":serialization_utils", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -85,35 +65,7 @@ py_test( srcs = ["dataset_data_source_test.py"], deps = [ ":dataset_data_source", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_library( - name = "federated_context", - srcs = ["federated_context.py"], - deps = [ - ":structure_utils", - ":value_reference", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:get_context_stack", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", - ], -) - -py_test( - name = "federated_context_test", - srcs = ["federated_context_test.py"], - deps = [ - ":federated_context", - "//tensorflow_federated/python/core/impl/context_stack:context_base", - "//tensorflow_federated/python/core/impl/context_stack:context_stack_impl", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) @@ -122,10 +74,8 @@ py_library( srcs = ["file_program_state_manager.py"], deps = [ ":file_utils", - ":program_state_manager", ":structure_utils", - ":value_reference", - "//tensorflow_federated/python/common_libs:serializable", + "@federated_language//federated_language", ], ) @@ -135,8 +85,8 @@ py_test( deps = [ ":file_program_state_manager", ":file_utils", - ":program_state_manager", ":program_test_utils", + "@federated_language//federated_language", ], ) @@ -145,9 +95,8 @@ py_library( srcs = ["file_release_manager.py"], deps = [ ":file_utils", - ":release_manager", ":structure_utils", - ":value_reference", + "@federated_language//federated_language", ], ) @@ -158,8 +107,8 @@ py_test( ":file_release_manager", ":file_utils", ":program_test_utils", - ":release_manager", ":structure_utils", + "@federated_language//federated_language", ], ) @@ -175,55 +124,13 @@ py_test( deps = [":file_utils"], ) -py_library( - name = "logging_release_manager", - srcs = ["logging_release_manager.py"], - deps = [ - ":release_manager", - ":value_reference", - ], -) - -py_test( - name = "logging_release_manager_test", - srcs = ["logging_release_manager_test.py"], - deps = [ - ":logging_release_manager", - ":program_test_utils", - ], -) - -py_library( - name = "memory_release_manager", - srcs = ["memory_release_manager.py"], - deps = [ - ":release_manager", - ":value_reference", - ], -) - -py_test( - name = "memory_release_manager_test", - srcs = ["memory_release_manager_test.py"], - deps = [ - ":memory_release_manager", - ":program_test_utils", - ], -) - py_library( name = "native_platform", srcs = ["native_platform.py"], deps = [ - ":federated_context", ":structure_utils", - ":value_reference", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_conversions", + "@federated_language//federated_language", ], ) @@ -234,61 +141,19 @@ py_test( ":native_platform", ":program_test_utils", ":structure_utils", - ":value_reference", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/execution_contexts:async_execution_context", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", + "@federated_language//federated_language", ], ) -py_library( - name = "program_state_manager", - srcs = ["program_state_manager.py"], - deps = [ - ":structure_utils", - ":value_reference", - "//tensorflow_federated/python/common_libs:serializable", - ], -) - -py_test( - name = "program_state_manager_test", - srcs = ["program_state_manager_test.py"], - deps = [":program_state_manager"], -) - py_library( name = "program_test_utils", testonly = True, srcs = ["program_test_utils.py"], deps = [ - ":value_reference", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:serializable", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - -py_library( - name = "release_manager", - srcs = ["release_manager.py"], - deps = [ - ":structure_utils", - ":value_reference", "//tensorflow_federated/python/common_libs:py_typecheck", - ], -) - -py_test( - name = "release_manager_test", - srcs = ["release_manager_test.py"], - deps = [ - ":program_test_utils", - ":release_manager", + "@federated_language//federated_language", ], ) @@ -297,10 +162,8 @@ py_library( srcs = ["serialization_utils.py"], deps = [ ":structure_utils", - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/common_libs:serializable", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_serialization", + "@federated_language//federated_language", + "@federated_language//federated_language/proto:computation_py_pb2", ], ) @@ -310,7 +173,7 @@ py_test( deps = [ ":program_test_utils", ":serialization_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) @@ -333,9 +196,8 @@ py_library( name = "tensorboard_release_manager", srcs = ["tensorboard_release_manager.py"], deps = [ - ":release_manager", ":structure_utils", - ":value_reference", + "@federated_language//federated_language", ], ) @@ -347,22 +209,3 @@ py_test( ":tensorboard_release_manager", ], ) - -py_library( - name = "value_reference", - srcs = ["value_reference.py"], - deps = [ - ":structure_utils", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:typed_object", - ], -) - -py_test( - name = "value_reference_test", - srcs = ["value_reference_test.py"], - deps = [ - ":program_test_utils", - ":value_reference", - ], -) diff --git a/tensorflow_federated/python/program/__init__.py b/tensorflow_federated/python/program/__init__.py index 3d2f41ba47..ebe33235a5 100644 --- a/tensorflow_federated/python/program/__init__.py +++ b/tensorflow_federated/python/program/__init__.py @@ -13,45 +13,61 @@ # limitations under the License. """Libraries for creating federated programs.""" +import federated_language # pylint: disable=g-importing-member from tensorflow_federated.python.program.client_id_data_source import ClientIdDataSource from tensorflow_federated.python.program.client_id_data_source import ClientIdDataSourceIterator -from tensorflow_federated.python.program.data_source import FederatedDataSource -from tensorflow_federated.python.program.data_source import FederatedDataSourceIterator + +FederatedDataSource = federated_language.program.FederatedDataSource +FederatedDataSourceIterator = ( + federated_language.program.FederatedDataSourceIterator +) from tensorflow_federated.python.program.dataset_data_source import DatasetDataSource from tensorflow_federated.python.program.dataset_data_source import DatasetDataSourceIterator -from tensorflow_federated.python.program.federated_context import check_in_federated_context -from tensorflow_federated.python.program.federated_context import ComputationArg -from tensorflow_federated.python.program.federated_context import contains_only_server_placed_data -from tensorflow_federated.python.program.federated_context import FederatedContext + +check_in_federated_context = ( + federated_language.program.check_in_federated_context +) +ComputationArg = federated_language.program.ComputationArg +contains_only_server_placed_data = ( + federated_language.program.contains_only_server_placed_data +) +FederatedContext = federated_language.program.FederatedContext from tensorflow_federated.python.program.file_program_state_manager import FileProgramStateManager from tensorflow_federated.python.program.file_release_manager import CSVFileReleaseManager from tensorflow_federated.python.program.file_release_manager import CSVKeyFieldnameNotFoundError from tensorflow_federated.python.program.file_release_manager import CSVSaveMode from tensorflow_federated.python.program.file_release_manager import SavedModelFileReleaseManager -from tensorflow_federated.python.program.logging_release_manager import LoggingReleaseManager -from tensorflow_federated.python.program.memory_release_manager import MemoryReleaseManager + +LoggingReleaseManager = federated_language.program.LoggingReleaseManager +MemoryReleaseManager = federated_language.program.MemoryReleaseManager from tensorflow_federated.python.program.native_platform import NativeFederatedContext from tensorflow_federated.python.program.native_platform import NativeValueReference -from tensorflow_federated.python.program.program_state_manager import ProgramStateExistsError -from tensorflow_federated.python.program.program_state_manager import ProgramStateManager -from tensorflow_federated.python.program.program_state_manager import ProgramStateNotFoundError -from tensorflow_federated.python.program.program_state_manager import ProgramStateStructure -from tensorflow_federated.python.program.program_state_manager import ProgramStateValue -from tensorflow_federated.python.program.release_manager import DelayedReleaseManager -from tensorflow_federated.python.program.release_manager import FilteringReleaseManager -from tensorflow_federated.python.program.release_manager import GroupingReleaseManager -from tensorflow_federated.python.program.release_manager import NotFilterableError -from tensorflow_federated.python.program.release_manager import PeriodicReleaseManager -from tensorflow_federated.python.program.release_manager import ReleasableStructure -from tensorflow_federated.python.program.release_manager import ReleasableValue -from tensorflow_federated.python.program.release_manager import ReleaseManager + +ProgramStateExistsError = federated_language.program.ProgramStateExistsError +ProgramStateManager = federated_language.program.ProgramStateManager +ProgramStateNotFoundError = federated_language.program.ProgramStateNotFoundError +ProgramStateStructure = federated_language.program.ProgramStateStructure +ProgramStateValue = federated_language.program.ProgramStateValue +DelayedReleaseManager = federated_language.program.DelayedReleaseManager +FilteringReleaseManager = federated_language.program.FilteringReleaseManager +GroupingReleaseManager = federated_language.program.GroupingReleaseManager +NotFilterableError = federated_language.program.NotFilterableError +PeriodicReleaseManager = federated_language.program.PeriodicReleaseManager +ReleasableStructure = federated_language.program.ReleasableStructure +ReleasableValue = federated_language.program.ReleasableValue +ReleaseManager = federated_language.program.ReleaseManager from tensorflow_federated.python.program.tensorboard_release_manager import TensorBoardReleaseManager -from tensorflow_federated.python.program.value_reference import MaterializableStructure -from tensorflow_federated.python.program.value_reference import MaterializableTypeSignature -from tensorflow_federated.python.program.value_reference import MaterializableValue -from tensorflow_federated.python.program.value_reference import MaterializableValueReference -from tensorflow_federated.python.program.value_reference import materialize_value -from tensorflow_federated.python.program.value_reference import MaterializedStructure -from tensorflow_federated.python.program.value_reference import MaterializedValue + +MaterializableStructure = federated_language.program.MaterializableStructure +MaterializableTypeSignature = ( + federated_language.program.MaterializableTypeSignature +) +MaterializableValue = federated_language.program.MaterializableValue +MaterializableValueReference = ( + federated_language.program.MaterializableValueReference +) +materialize_value = federated_language.program.materialize_value +MaterializedStructure = federated_language.program.MaterializedStructure +MaterializedValue = federated_language.program.MaterializedValue # pylint: enable=g-importing-member diff --git a/tensorflow_federated/python/program/client_id_data_source.py b/tensorflow_federated/python/program/client_id_data_source.py index 7d6b379153..c3de371569 100644 --- a/tensorflow_federated/python/program/client_id_data_source.py +++ b/tensorflow_federated/python/program/client_id_data_source.py @@ -17,15 +17,15 @@ import random from typing import Optional +import federated_language import numpy as np -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.program import data_source from tensorflow_federated.python.program import serialization_utils -class ClientIdDataSourceIterator(data_source.FederatedDataSourceIterator): +class ClientIdDataSourceIterator( + federated_language.program.FederatedDataSourceIterator +): """A `tff.program.FederatedDataSourceIterator` backed by client ids. A `tff.program.FederatedDataSourceIterator` backed by sequence of client ids, @@ -48,8 +48,8 @@ def __init__(self, client_ids: Sequence[str]): raise ValueError('Expected `client_ids` to not be empty.') self._client_ids = client_ids - self._federated_type = computation_types.FederatedType( - np.str_, placements.CLIENTS + self._federated_type = federated_language.FederatedType( + np.str_, federated_language.CLIENTS ) @classmethod @@ -68,7 +68,7 @@ def to_bytes(self) -> bytes: return client_ids_bytes @property - def federated_type(self) -> computation_types.FederatedType: + def federated_type(self) -> federated_language.FederatedType: """The type of the data returned by calling `select`.""" return self._federated_type @@ -100,7 +100,7 @@ def __eq__(self, other: object) -> bool: return self._client_ids == other._client_ids -class ClientIdDataSource(data_source.FederatedDataSource): +class ClientIdDataSource(federated_language.program.FederatedDataSource): """A `tff.program.FederatedDataSource` backed by client ids.""" def __init__(self, client_ids: Sequence[str]): @@ -117,15 +117,15 @@ def __init__(self, client_ids: Sequence[str]): raise ValueError('Expected `client_ids` to not be empty.') self._client_ids = client_ids - self._federated_type = computation_types.FederatedType( - np.str_, placements.CLIENTS + self._federated_type = federated_language.FederatedType( + np.str_, federated_language.CLIENTS ) @property - def federated_type(self) -> computation_types.FederatedType: + def federated_type(self) -> federated_language.FederatedType: """The type of the data returned by calling `select` on an iterator.""" return self._federated_type - def iterator(self) -> data_source.FederatedDataSourceIterator: + def iterator(self) -> federated_language.program.FederatedDataSourceIterator: """Returns a new iterator for retrieving client ids from this data source.""" return ClientIdDataSourceIterator(self._client_ids) diff --git a/tensorflow_federated/python/program/client_id_data_source_test.py b/tensorflow_federated/python/program/client_id_data_source_test.py index 852168ec18..8162e59c8a 100644 --- a/tensorflow_federated/python/program/client_id_data_source_test.py +++ b/tensorflow_federated/python/program/client_id_data_source_test.py @@ -14,10 +14,8 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.program import client_id_data_source @@ -28,8 +26,8 @@ def test_init_sets_federated_type(self): iterator = client_id_data_source.ClientIdDataSourceIterator(client_ids) - federated_type = computation_types.FederatedType( - np.str_, placements.CLIENTS + federated_type = federated_language.FederatedType( + np.str_, federated_language.CLIENTS ) self.assertEqual(iterator.federated_type, federated_type) @@ -88,8 +86,8 @@ def test_init_sets_federated_type(self): data_source = client_id_data_source.ClientIdDataSource(client_ids) - federated_type = computation_types.FederatedType( - np.str_, placements.CLIENTS + federated_type = federated_language.FederatedType( + np.str_, federated_language.CLIENTS ) self.assertEqual(data_source.federated_type, federated_type) diff --git a/tensorflow_federated/python/program/data_source.py b/tensorflow_federated/python/program/data_source.py deleted file mode 100644 index e199778653..0000000000 --- a/tensorflow_federated/python/program/data_source.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines abstract interfaces for representing data sources.""" - -import abc -from typing import Optional - -from tensorflow_federated.python.common_libs import serializable -from tensorflow_federated.python.core.impl.types import computation_types - - -class FederatedDataSourceIterator(serializable.Serializable, abc.ABC): - """An abstract interface for representing federated data source iterators. - - This interface abstracts away the specifics of iterating over data in a data - source. - - Things one can do with a data source iterator: - - * Determine the type of the data supplied by this iterator by inspecting - the `federated_type` property. The type returned must match that of the data - source that returned this iterator. - - * Return a new selection of federated data from the iterator by invoking - `select`. - - Please see `tff.program.FederatedDataSource` for additional context and the - high-level description of how to use data sources. - """ - - @property - @abc.abstractmethod - def federated_type(self) -> computation_types.FederatedType: - """The type of the data returned by calling `select`.""" - raise NotImplementedError - - @abc.abstractmethod - def select(self, k: Optional[int] = None) -> object: - """Returns a new selection of federated data from this iterator. - - Args: - k: An optional number of elements to select. Must be a positive integer, - or `None` if unspecified. - - Returns: - An object of type `federated_type` representing the selected data, and - that can be supplied as an argument to a `tff.Computation`. See - `tff.program.FederatedContext` for more information about these types. - """ - raise NotImplementedError - - -class FederatedDataSource(abc.ABC): - """An abstract interface for representing federated data sources. - - This interface abstracts away the specifics of working with various types of - data sources. - - Things one can do with a data source: - - * Determine the type of the data supplied by this data source by inspecting - the `federated_type` property. The type returned should be a federated type. - Note that depending on whether this data source contains one or a number of - federated datasets, the type may or may not be a struct (with individual - datasets appearing as elements of this struct). - - * Construct a new iterator for this data source by invoking `iterator` on it. - Each iterator represents an independent pass over the data from this data - source. - """ - - @property - @abc.abstractmethod - def federated_type(self) -> computation_types.FederatedType: - """The type of the data returned by calling `select` on an iterator.""" - raise NotImplementedError - - @abc.abstractmethod - def iterator(self) -> FederatedDataSourceIterator: - """Returns a new iterator for retrieving data from this data source.""" - raise NotImplementedError diff --git a/tensorflow_federated/python/program/dataset_data_source.py b/tensorflow_federated/python/program/dataset_data_source.py index d75956744d..91da5e818f 100644 --- a/tensorflow_federated/python/program/dataset_data_source.py +++ b/tensorflow_federated/python/program/dataset_data_source.py @@ -17,16 +17,16 @@ import random from typing import Optional +import federated_language import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.program import data_source from tensorflow_federated.python.program import serialization_utils -class DatasetDataSourceIterator(data_source.FederatedDataSourceIterator): +class DatasetDataSourceIterator( + federated_language.program.FederatedDataSourceIterator +): """A `tff.program.FederatedDataSourceIterator` backed by `tf.data.Dataset`s. A `tff.program.FederatedDataSourceIterator` backed by a sequence of @@ -38,7 +38,7 @@ class DatasetDataSourceIterator(data_source.FederatedDataSourceIterator): def __init__( self, datasets: Sequence[tf.data.Dataset], - federated_type: computation_types.FederatedType, + federated_type: federated_language.FederatedType, ): """Returns an initialized `tff.program.DatasetDataSourceIterator`. @@ -77,7 +77,7 @@ def from_bytes(cls, buffer: bytes) -> 'DatasetDataSourceIterator': federated_type, _ = serialization_utils.unpack_type_spec_from( buffer, offset=offset ) - if not isinstance(federated_type, computation_types.FederatedType): + if not isinstance(federated_type, federated_language.FederatedType): raise TypeError( 'Expected `federated_type` to be a `tff.FederatedType`, found ' f'{type(federated_type)}.' @@ -97,7 +97,7 @@ def to_bytes(self) -> bytes: return datasets_bytes + federated_type_bytes @property - def federated_type(self) -> computation_types.FederatedType: + def federated_type(self) -> federated_language.FederatedType: """The type of the data returned by calling `select`.""" return self._federated_type @@ -134,7 +134,7 @@ def __eq__(self, other: object) -> bool: return True -class DatasetDataSource(data_source.FederatedDataSource): +class DatasetDataSource(federated_language.program.FederatedDataSource): """A `tff.program.FederatedDataSource` backed by `tf.data.Dataset`s. A `tff.program.FederatedDataSource` backed by a sequence of @@ -168,12 +168,13 @@ def __init__(self, datasets: Sequence[tf.data.Dataset]): self._datasets = datasets element_type = tensorflow_types.to_type(element_spec) - self._federated_type = computation_types.FederatedType( - computation_types.SequenceType(element_type), placements.CLIENTS + self._federated_type = federated_language.FederatedType( + federated_language.SequenceType(element_type), + federated_language.CLIENTS, ) @property - def federated_type(self) -> computation_types.FederatedType: + def federated_type(self) -> federated_language.FederatedType: """The type of the data returned by calling `select` on an iterator.""" return self._federated_type diff --git a/tensorflow_federated/python/program/dataset_data_source_test.py b/tensorflow_federated/python/program/dataset_data_source_test.py index fcf3906d91..ae12283812 100644 --- a/tensorflow_federated/python/program/dataset_data_source_test.py +++ b/tensorflow_federated/python/program/dataset_data_source_test.py @@ -14,11 +14,9 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.program import dataset_data_source @@ -26,8 +24,8 @@ class DatasetDataSourceIteratorTest(parameterized.TestCase): def test_init_does_not_raise_type_error(self): datasets = [tf.data.Dataset.from_tensor_slices([1, 2, 3])] * 3 - federated_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) try: @@ -37,8 +35,8 @@ def test_init_does_not_raise_type_error(self): def test_init_raises_value_error_with_datasets_empty(self): datasets = [] - federated_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) with self.assertRaises(ValueError): @@ -49,8 +47,8 @@ def test_init_raises_value_error_with_datasets_different_types(self): tf.data.Dataset.from_tensor_slices([1, 2, 3]), tf.data.Dataset.from_tensor_slices(['a', 'b', 'c']), ] - federated_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) with self.assertRaises(ValueError): @@ -63,8 +61,8 @@ def test_init_raises_value_error_with_datasets_different_types(self): ) def test_select_returns_datasets_with_k(self, k): datasets = [tf.data.Dataset.from_tensor_slices([1, 2, 3])] * 3 - federated_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) iterator = dataset_data_source.DatasetDataSourceIterator( datasets, federated_type @@ -84,8 +82,8 @@ def test_select_returns_datasets_with_k(self, k): ) def test_select_raises_value_error_with_k(self, k): datasets = [tf.data.Dataset.from_tensor_slices([1, 2, 3])] * 3 - federated_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) iterator = dataset_data_source.DatasetDataSourceIterator( datasets, federated_type @@ -96,8 +94,8 @@ def test_select_raises_value_error_with_k(self, k): def test_serializable(self): datasets = [tf.data.Dataset.from_tensor_slices([1, 2, 3])] * 3 - federated_type = computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(np.int32), federated_language.CLIENTS ) iterator = dataset_data_source.DatasetDataSourceIterator( datasets, federated_type @@ -123,8 +121,8 @@ def test_init_sets_federated_type(self, tensors, dtype): data_source = dataset_data_source.DatasetDataSource(datasets) - federated_type = computation_types.FederatedType( - computation_types.SequenceType(dtype), placements.CLIENTS + federated_type = federated_language.FederatedType( + federated_language.SequenceType(dtype), federated_language.CLIENTS ) self.assertEqual(data_source.federated_type, federated_type) diff --git a/tensorflow_federated/python/program/federated_context.py b/tensorflow_federated/python/program/federated_context.py deleted file mode 100644 index 6112976986..0000000000 --- a/tensorflow_federated/python/program/federated_context.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2022, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines an abstract interface for representing a federated context.""" - -import abc -from typing import Optional, Union - -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import get_context_stack -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis -from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference - - -ComputationArg = Union[ - value_reference.MaterializableStructure, - object, - computation_base.Computation, -] - - -def contains_only_server_placed_data( - type_signature: computation_types.Type, -) -> bool: - """Determines if `type_signature` contains only server-placed data. - - Determines if `type_signature` contains only: - * `tff.StructType`s - * `tff.SequenceType`s - * server-placed `tff.FederatedType`s - * `tff.TensorType`s - - Args: - type_signature: The `tff.Type` to test. - - Returns: - `True` if `type_signature` contains only server-placed data, otherwise - `False`. - """ - - def _predicate(type_spec: computation_types.Type) -> bool: - return isinstance( - type_spec, - ( - computation_types.StructType, - computation_types.SequenceType, - computation_types.TensorType, - ), - ) or ( - isinstance(type_spec, computation_types.FederatedType) - and type_spec.placement is placements.SERVER - ) - - return type_analysis.contains_only(type_signature, _predicate) - - -class FederatedContext(context_base.SyncContext): - """An abstract interface representing a federated context. - - A federated context supports invoking a limited set of `tff.Computation`s, - making guarantees about what a `tff.Computation` can accept as an argument and - what it returns when invoked. - - ## Restrictions on the TensorFlow Federated Type - - Arguments can be nested structures of values corresponding to the TensorFlow - Federated type signature of the `tff.Computation`: - - * Server-placed values must be represented by - `tff.program.MaterializableStructure`. - * Client-placed values must be represented by structures of values returned - by a `tff.program.FederatedDataSourceIterator`. - - Return values can be structures of `tff.program.MaterializableValueReference`s - or a single `tff.program.MaterializableValueReference`, where a reference - corresponds to the tensor-type of the TensorFlow Federated type signature in - the return value of the invoked `tff.Computation`. - - ## TensorFlow Federated Type to Python Representation - - In order to interact with the value returned by a `tff.Computation`, it is - helpful to be able to reason about the Python type of this value. In some way - this Python type must depend on the TensorFlow Federated type signature of the - associated value. To provide uniformity of experience and ease of reasoning, - we specify the Python representation of values in a manner that can be stated - entirely in the TensorFlow Federated typesystem. - - We have chosen to limit the TensorFlow Federated type signatures of invoked - `tff.Computation`s to disallow the returning of client-placed values, - `tff.SequenceTypes`, and `tff.FunctionTypes`, in order to reduced the area - which needs to be supported by federated programs. Below we describe the - mapping between TensorFlow Federated type signatures and Python - representations of values that can be passed as arguments to or returned as - results from `tff.Computation`s. - - Python representations of values that can be *accepted as an arguments to* or - *returned as a value from* a `tff.Computation`: - - | TensorFlow Federated Type | Python Representation | - | -------------------------- | ------------------------------------------ | - | `tff.TensorType` | `tff.program.MaterializableValueReference` | - | `tff.SequenceType` | `tff.program.MaterializableValueReference` | - | `tff.FederatedType` | Python representation of the `member` of | - : (server-placed) : the `tff.FederatedType` : - | `tff.StructWithPythonType` | Python container of the | - : : `tff.StructWithPythonType` : - | `tff.StructType` (with no | `collections.OrderedDict` | - : Python type, all fields : : - : named) : : - | `tff.StructType` (with no | `tuple` | - : Python type, no fields : : - : named) : : - - Python representations of values that can be only be *accepted as an arguments - to* a `tff.Computation`: - - | TFF Type | Python Representation | - | ------------------- | --------------------------------------- | - | `tff.FederatedType` | Opaque object returned by | - : (client-placed) : `tff.program.DataSourceIterator.select` : - | `tff.FunctionType` | `tff.Computation` | - """ - - @abc.abstractmethod - def invoke( - self, - comp: computation_base.Computation, - arg: Optional[ComputationArg], - ) -> structure_utils.Structure[value_reference.MaterializableValueReference]: - """Invokes the `comp` with the argument `arg`. - - Args: - comp: The `tff.Computation` being invoked. - arg: The optional argument of `comp`; server-placed values must be - represented by `tff.program.MaterializableStructure`, and client-placed - values must be represented by structures of values returned by a - `tff.program.FederatedDataSourceIterator`. - - Returns: - The result of invocation; a structure of - `tff.program.MaterializableValueReference`. - - Raises: - ValueError: If the result type of `comp` does not contain only structures, - server-placed values, or tensors. - """ - raise NotImplementedError - - -def check_in_federated_context() -> None: - """Checks if the current context is a `tff.program.FederatedContext`.""" - context_stack = get_context_stack.get_context_stack() - if not isinstance(context_stack.current, FederatedContext): - raise ValueError( - 'Expected the current context to be a `tff.program.FederatedContext`, ' - f'found {type(context_stack.current)}.' - ) diff --git a/tensorflow_federated/python/program/federated_context_test.py b/tensorflow_federated/python/program/federated_context_test.py deleted file mode 100644 index e69d4c292b..0000000000 --- a/tensorflow_federated/python/program/federated_context_test.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright 2022, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np - -from tensorflow_federated.python.core.impl.context_stack import context_base -from tensorflow_federated.python.core.impl.context_stack import context_stack_impl -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.program import federated_context - - -class TestContext(context_base.SyncContext): - - def invoke(self, comp, arg): - return None - - -class ContainsOnlyServerPlacedDataTest(parameterized.TestCase): - - @parameterized.named_parameters( - ( - 'struct_unnamed', - computation_types.StructWithPythonType( - [np.bool_, np.int32, np.str_], list - ), - ), - ( - 'struct_named', - computation_types.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ('c', np.str_), - ], - collections.OrderedDict, - ), - ), - ( - 'struct_nested', - computation_types.StructWithPythonType( - [ - ( - 'x', - computation_types.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ], - collections.OrderedDict, - ), - ), - ( - 'y', - computation_types.StructWithPythonType( - [ - ('c', np.str_), - ], - collections.OrderedDict, - ), - ), - ], - collections.OrderedDict, - ), - ), - ( - 'federated_struct', - computation_types.FederatedType( - computation_types.StructWithPythonType( - [ - ('a', np.bool_), - ('b', np.int32), - ('c', np.str_), - ], - collections.OrderedDict, - ), - placements.SERVER, - ), - ), - ( - 'federated_sequence', - computation_types.FederatedType( - computation_types.SequenceType(np.int32), placements.SERVER - ), - ), - ( - 'federated_tensor', - computation_types.FederatedType(np.int32, placements.SERVER), - ), - ('sequence', computation_types.SequenceType(np.int32)), - ('tensor', computation_types.TensorType(np.int32)), - ) - def test_returns_true(self, type_signature): - result = federated_context.contains_only_server_placed_data(type_signature) - - self.assertTrue(result) - - @parameterized.named_parameters( - ( - 'federated_tensor', - computation_types.FederatedType(np.int32, placements.CLIENTS), - ), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('placement', computation_types.PlacementType()), - ) - def test_returns_false(self, type_signature): - result = federated_context.contains_only_server_placed_data(type_signature) - - self.assertFalse(result) - - -class CheckInFederatedContextTest(parameterized.TestCase): - - def test_does_not_raise_value_error_with_context(self): - context = mock.create_autospec( - federated_context.FederatedContext, spec_set=True, instance=True - ) - - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - with context_stack_impl.context_stack.install(context): - try: - federated_context.check_in_federated_context() - except TypeError: - self.fail('Raised `ValueError` unexpectedly.') - - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - def test_raises_value_error_with_context(self): - context = TestContext() - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - with context_stack_impl.context_stack.install(context): - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - def test_raises_value_error_with_context_nested(self): - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - context = mock.create_autospec( - federated_context.FederatedContext, spec_set=True, instance=True - ) - with context_stack_impl.context_stack.install(context): - try: - federated_context.check_in_federated_context() - except TypeError: - self.fail('Raised `ValueError` unexpectedly.') - - context = TestContext() - with context_stack_impl.context_stack.install(context): - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - with self.assertRaises(ValueError): - federated_context.check_in_federated_context() - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/program/file_program_state_manager.py b/tensorflow_federated/python/program/file_program_state_manager.py index 91d17855d5..e65930bd46 100644 --- a/tensorflow_federated/python/program/file_program_state_manager.py +++ b/tensorflow_federated/python/program/file_program_state_manager.py @@ -26,18 +26,16 @@ from typing import Optional, Union from absl import logging +import federated_language import tensorflow as tf -from tensorflow_federated.python.common_libs import serializable from tensorflow_federated.python.program import file_utils -from tensorflow_federated.python.program import program_state_manager from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference class FileProgramStateManager( - program_state_manager.ProgramStateManager[ - program_state_manager.ProgramStateStructure + federated_language.program.ProgramStateManager[ + federated_language.program.ProgramStateStructure ] ): """A `tff.program.ProgramStateManager` that is backed by a file system. @@ -160,8 +158,10 @@ def _get_path_for_version(self, version: int) -> str: return os.path.join(self._root_dir, basename) async def load( - self, version: int, structure: program_state_manager.ProgramStateStructure - ) -> program_state_manager.ProgramStateStructure: + self, + version: int, + structure: federated_language.program.ProgramStateStructure, + ) -> federated_language.program.ProgramStateStructure: """Returns the program state for the given `version`. Args: @@ -176,12 +176,12 @@ async def load( """ path = self._get_path_for_version(version) if not await file_utils.exists(path): - raise program_state_manager.ProgramStateNotFoundError(version) + raise federated_language.program.ProgramStateNotFoundError(version) program_state = await file_utils.read_saved_model(path) def _normalize( - value: program_state_manager.ProgramStateValue, - ) -> program_state_manager.ProgramStateValue: + value: federated_language.program.ProgramStateValue, + ) -> federated_language.program.ProgramStateValue: """Returns a normalized value. The `tff.program.FileProgramStateManager` saves and loads program state to @@ -201,7 +201,7 @@ def _normalize( normalized_state = structure_utils.map_structure(_normalize, program_state) def _deserialize_as(structure, value): - if isinstance(structure, serializable.Serializable): + if isinstance(structure, federated_language.Serializable): serializable_cls = type(structure) value = serializable_cls.from_bytes(value) return value @@ -255,7 +255,7 @@ async def remove_all(self) -> None: async def save( self, - program_state: program_state_manager.ProgramStateStructure, + program_state: federated_language.program.ProgramStateStructure, version: int, ) -> None: """Saves `program_state` for the given `version`. @@ -271,13 +271,15 @@ async def save( """ path = self._get_path_for_version(version) if await file_utils.exists(path): - raise program_state_manager.ProgramStateExistsError( + raise federated_language.program.ProgramStateExistsError( version=version, path=self._root_dir ) - materialized_state = await value_reference.materialize_value(program_state) + materialized_state = await federated_language.program.materialize_value( + program_state + ) def _serialize(value): - if isinstance(value, serializable.Serializable): + if isinstance(value, federated_language.Serializable): value = value.to_bytes() return value diff --git a/tensorflow_federated/python/program/file_program_state_manager_test.py b/tensorflow_federated/python/program/file_program_state_manager_test.py index 732d270554..4da7386c76 100644 --- a/tensorflow_federated/python/program/file_program_state_manager_test.py +++ b/tensorflow_federated/python/program/file_program_state_manager_test.py @@ -21,12 +21,12 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tree from tensorflow_federated.python.program import file_program_state_manager from tensorflow_federated.python.program import file_utils -from tensorflow_federated.python.program import program_state_manager from tensorflow_federated.python.program import program_test_utils @@ -414,7 +414,9 @@ async def test_raises_program_state_not_found_error_with_no_program_state( version = 0 structure = 'state' - with self.assertRaises(program_state_manager.ProgramStateNotFoundError): + with self.assertRaises( + federated_language.program.ProgramStateNotFoundError + ): await program_state_mngr.load(version, structure) async def test_raises_program_state_not_found_error_with_unknown_version( @@ -430,7 +432,9 @@ async def test_raises_program_state_not_found_error_with_unknown_version( unknown_version = 0 structure = 'state' - with self.assertRaises(program_state_manager.ProgramStateNotFoundError): + with self.assertRaises( + federated_language.program.ProgramStateNotFoundError + ): await program_state_mngr.load(unknown_version, structure) @@ -878,7 +882,7 @@ async def test_raises_program_state_exists_error_with_existing_version(self): await program_state_mngr.save(program_state, version) - with self.assertRaises(program_state_manager.ProgramStateExistsError): + with self.assertRaises(federated_language.program.ProgramStateExistsError): await program_state_mngr.save(program_state, version) diff --git a/tensorflow_federated/python/program/file_release_manager.py b/tensorflow_federated/python/program/file_release_manager.py index 7158b8ec2c..727f396ad9 100644 --- a/tensorflow_federated/python/program/file_release_manager.py +++ b/tensorflow_federated/python/program/file_release_manager.py @@ -30,13 +30,12 @@ import random from typing import Union +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.program import file_utils -from tensorflow_federated.python.program import release_manager from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference class CSVKeyFieldnameNotFoundError(Exception): @@ -62,7 +61,9 @@ class CSVSaveMode(enum.Enum): class CSVFileReleaseManager( - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ): """A `tff.program.ReleaseManager` that releases values to a CSV file. @@ -155,7 +156,9 @@ def _read_values(self) -> tuple[list[str], list[dict[str, str]]]: def _write_values( self, fieldnames: Sequence[str], - values: Iterable[Mapping[str, release_manager.ReleasableStructure]], + values: Iterable[ + Mapping[str, federated_language.program.ReleasableStructure] + ], ) -> None: """Writes `fieldnames` and `values` to the managed CSV.""" path = os.fspath(self._file_path) @@ -175,7 +178,7 @@ def _write_values( tf.io.gfile.rename(temp_path, self._file_path, overwrite=True) async def _write_value( - self, value: Mapping[str, release_manager.ReleasableStructure] + self, value: Mapping[str, federated_language.program.ReleasableStructure] ) -> None: """Writes `value` to the managed CSV.""" loop = asyncio.get_running_loop() @@ -185,7 +188,7 @@ async def _write_value( await loop.run_in_executor(None, self._write_values, fieldnames, values) async def _append_value( - self, value: Mapping[str, release_manager.ReleasableStructure] + self, value: Mapping[str, federated_language.program.ReleasableStructure] ) -> None: """Appends `value` to the managed CSV.""" @@ -200,7 +203,7 @@ def _read_fieldnames_only() -> list[str]: def _append_value( fieldnames: Sequence[str], - value: Mapping[str, release_manager.ReleasableStructure], + value: Mapping[str, federated_language.program.ReleasableStructure], ) -> None: try: with tf.io.gfile.GFile(self._file_path, 'a') as file: @@ -249,7 +252,7 @@ async def _remove_values_greater_than_key(self, key: int) -> None: self._latest_key = key async def release( - self, value: release_manager.ReleasableStructure, key: int + self, value: federated_language.program.ReleasableStructure, key: int ) -> None: """Releases `value` from a federated program. @@ -264,14 +267,14 @@ async def release( """ _, materialized_value = await asyncio.gather( self._remove_values_greater_than_key(key - 1), - value_reference.materialize_value(value), + federated_language.program.materialize_value(value), ) flattened_value = structure_utils.flatten_with_name(materialized_value) def _normalize( - value: value_reference.MaterializedValue, - ) -> value_reference.MaterializedValue: + value: federated_language.program.MaterializedValue, + ) -> federated_language.program.MaterializedValue: if isinstance(value, tf.data.Dataset): value = list(value) return np.array(value).tolist() @@ -287,8 +290,9 @@ def _normalize( class SavedModelFileReleaseManager( - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, release_manager.Key + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, + federated_language.program.Key, ] ): """A `tff.program.ReleaseManager` that releases values to a file system. @@ -331,7 +335,7 @@ def __init__( self._root_dir = root_dir self._prefix = prefix - def _get_path_for_key(self, key: release_manager.Key) -> str: + def _get_path_for_key(self, key: federated_language.program.Key) -> str: """Returns the path for the given `key`. This method does not assert that the given `key` or the returned path @@ -344,7 +348,9 @@ def _get_path_for_key(self, key: release_manager.Key) -> str: return os.path.join(self._root_dir, basename) async def release( - self, value: release_manager.ReleasableStructure, key: release_manager.Key + self, + value: federated_language.program.ReleasableStructure, + key: federated_language.program.Key, ) -> None: """Releases `value` from a federated program. @@ -353,13 +359,15 @@ async def release( key: Used to reference (in the file system) the released `value`. """ path = self._get_path_for_key(key) - materialized_value = await value_reference.materialize_value(value) + materialized_value = await federated_language.program.materialize_value( + value + ) await file_utils.write_saved_model(materialized_value, path, overwrite=True) async def get_value( self, - key: release_manager.Key, - ) -> release_manager.ReleasableStructure: + key: federated_language.program.Key, + ) -> federated_language.program.ReleasableStructure: """Returns the value for the given `key`. The SavedModel format flattens and deterministicly orders keys. This @@ -379,12 +387,12 @@ async def get_value( path = self._get_path_for_key(key) if not await file_utils.exists(path): - raise release_manager.ReleasedValueNotFoundError(key) + raise federated_language.program.ReleasedValueNotFoundError(key) value = await file_utils.read_saved_model(path) def _normalize( - value: release_manager.ReleasableValue, - ) -> release_manager.ReleasableValue: + value: federated_language.program.ReleasableValue, + ) -> federated_language.program.ReleasableValue: """Returns a normalized value. The `tff.program.SavedModelFileReleaseManager` releases and gets values diff --git a/tensorflow_federated/python/program/file_release_manager_test.py b/tensorflow_federated/python/program/file_release_manager_test.py index ae22b7bbc3..3ac8be6174 100644 --- a/tensorflow_federated/python/program/file_release_manager_test.py +++ b/tensorflow_federated/python/program/file_release_manager_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf import tree @@ -30,13 +31,14 @@ from tensorflow_federated.python.program import file_release_manager from tensorflow_federated.python.program import file_utils from tensorflow_federated.python.program import program_test_utils -from tensorflow_federated.python.program import release_manager from tensorflow_federated.python.program import structure_utils def _read_values_from_csv( - file_path: Union[str, os.PathLike[str]] -) -> tuple[list[str], list[dict[str, release_manager.ReleasableStructure]]]: + file_path: Union[str, os.PathLike[str]], +) -> tuple[ + list[str], list[dict[str, federated_language.program.ReleasableStructure]] +]: with tf.io.gfile.GFile(file_path, 'r') as file: reader = csv.DictReader(file) fieldnames = list(reader.fieldnames) @@ -47,7 +49,9 @@ def _read_values_from_csv( def _write_values_to_csv( file_path: Union[str, os.PathLike[str]], fieldnames: Sequence[str], - values: Iterable[Mapping[str, release_manager.ReleasableStructure]], + values: Iterable[ + Mapping[str, federated_language.program.ReleasableStructure] + ], ) -> None: with tf.io.gfile.GFile(file_path, 'w') as file: writer = csv.DictWriter(file, fieldnames) @@ -1135,7 +1139,9 @@ async def test_raises_released_value_not_found_error_with_no_saved_value( release_mngr = file_release_manager.SavedModelFileReleaseManager(root_dir) key = 1 - with self.assertRaises(release_manager.ReleasedValueNotFoundError): + with self.assertRaises( + federated_language.program.ReleasedValueNotFoundError + ): await release_mngr.get_value(key) async def test_raises_released_value_not_found_error_with_unknown_key(self): @@ -1146,7 +1152,9 @@ async def test_raises_released_value_not_found_error_with_unknown_key(self): await release_mngr.release(value, key=key) unknown_key = 10 - with self.assertRaises(release_manager.ReleasedValueNotFoundError): + with self.assertRaises( + federated_language.program.ReleasedValueNotFoundError + ): await release_mngr.get_value(unknown_key) diff --git a/tensorflow_federated/python/program/logging_release_manager.py b/tensorflow_federated/python/program/logging_release_manager.py deleted file mode 100644 index ab8d48d639..0000000000 --- a/tensorflow_federated/python/program/logging_release_manager.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for releasing values from a federated program to logs.""" - -from typing import Optional - -from absl import logging - -from tensorflow_federated.python.program import release_manager -from tensorflow_federated.python.program import value_reference - - -class LoggingReleaseManager( - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, release_manager.Key - ] -): - """A `tff.program.ReleaseManager` that releases values to logs. - - A `tff.program.LoggingReleaseManager` is a utility for releasing values from a - federated program to logs and is used to release values from platform storage - to customer storage in a federated program. - - Values are released to logs as string representations of Python objects. When - the value is released, if the value is a value reference or a structure - containing value references, each value reference is materialized. - """ - - async def release( - self, - value: release_manager.ReleasableStructure, - key: Optional[release_manager.Key], - ) -> None: - """Releases `value` from a federated program. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: An optional value used to reference the released `value`. - """ - materialized_value = await value_reference.materialize_value(value) - logging.info('Releasing') - logging.info(' value: %s', materialized_value) - if key is not None: - logging.info(' key: %s', key) diff --git a/tensorflow_federated/python/program/logging_release_manager_test.py b/tensorflow_federated/python/program/logging_release_manager_test.py deleted file mode 100644 index 21cb4b1bf7..0000000000 --- a/tensorflow_federated/python/program/logging_release_manager_test.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -import tree - -from tensorflow_federated.python.program import logging_release_manager -from tensorflow_federated.python.program import program_test_utils - - -class LoggingReleaseManagerTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - # materialized values - ('none', None, None), - ('bool', True, True), - ('int', 1, 1), - ('str', 'abc', 'abc'), - ('numpy_int', np.int32(1), np.int32(1)), - ('numpy_array', np.array([1] * 3, np.int32), np.array([1] * 3, np.int32)), - # materializable value references - ( - 'value_reference_tensor', - program_test_utils.TestMaterializableValueReference(1), - 1, - ), - ( - 'value_reference_sequence', - program_test_utils.TestMaterializableValueReference([1, 2, 3]), - [1, 2, 3], - ), - # serializable values - ( - 'serializable_value', - program_test_utils.TestSerializable(1, 2), - program_test_utils.TestSerializable(1, 2), - ), - # other values - ( - 'attrs', - program_test_utils.TestAttrs(1, 2), - program_test_utils.TestAttrs(1, 2), - ), - # structures - ( - 'list', - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [ - True, - 1, - 'abc', - 2, - program_test_utils.TestSerializable(3, 4), - ], - ), - ('list_empty', [], []), - ( - 'list_nested', - [ - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - [ - [ - True, - 1, - 'abc', - 2, - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - ), - ( - 'dict_ordered', - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ( - 'dict_unordered', - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ('dict_empty', {}, {}), - ( - 'dict_nested', - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - ), - ( - 'named_tuple', - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=2, - e=program_test_utils.TestSerializable(3, 4), - ), - ), - ( - 'named_tuple_nested', - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=2, - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - ), - ) - async def test_release_logs_value(self, value, expected_value): - release_mngr = logging_release_manager.LoggingReleaseManager() - key = 1 - - with mock.patch('absl.logging.info') as mock_info: - await release_mngr.release(value, key=key) - - self.assertLen(mock_info.mock_calls, 3) - mock_info.assert_has_calls([ - mock.call(mock.ANY), - mock.call(mock.ANY, mock.ANY), - mock.call(mock.ANY, key), - ]) - call = mock_info.mock_calls[1] - _, args, kwargs = call - _, actual_value = args - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - self.assertEqual(kwargs, {}) - - @parameterized.named_parameters( - ('bool', True), - ('int', 1), - ('str', 'abc'), - ('list', []), - ) - async def test_release_logs_key(self, key): - release_mngr = logging_release_manager.LoggingReleaseManager() - value = 1 - - with mock.patch('absl.logging.info') as mock_info: - await release_mngr.release(value, key=key) - - self.assertLen(mock_info.mock_calls, 3) - mock_info.assert_has_calls([ - mock.call(mock.ANY), - mock.call(mock.ANY, value), - mock.call(mock.ANY, key), - ]) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/program/memory_release_manager.py b/tensorflow_federated/python/program/memory_release_manager.py deleted file mode 100644 index 514e4099a4..0000000000 --- a/tensorflow_federated/python/program/memory_release_manager.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for releasing values from a federated program to memory.""" - -import collections -from collections.abc import Hashable - -from tensorflow_federated.python.program import release_manager -from tensorflow_federated.python.program import value_reference - - -class MemoryReleaseManager( - release_manager.ReleaseManager[ - release_manager.ReleasableStructure, Hashable - ] -): - """A `tff.program.ReleaseManager` that releases values to memory. - - A `tff.program.MemoryReleaseManager` is a utility for releasing values from a - federated program to memory and is used to release values from platform - storage to customer storage in a federated program. - - Values are released to memory as Python objects. When the value is released, - if the value is a value reference or a structure containing value references, - each value reference is materialized. - """ - - def __init__(self): - """Returns an initialized `tff.program.MemoryReleaseManager`.""" - self._values = collections.OrderedDict() - - async def release( - self, value: release_manager.ReleasableStructure, key: Hashable - ) -> None: - """Releases `value` from a federated program. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: A hashable value used to reference the released `value`. - """ - materialized_value = await value_reference.materialize_value(value) - self._values[key] = materialized_value - - def remove_all(self) -> None: - """Removes all program states.""" - self._values = collections.OrderedDict() - - def values( - self, - ) -> collections.OrderedDict[Hashable, release_manager.ReleasableStructure]: - """Returns an `collections.OrderedDict` of all keys and released values.""" - return self._values.copy() diff --git a/tensorflow_federated/python/program/memory_release_manager_test.py b/tensorflow_federated/python/program/memory_release_manager_test.py deleted file mode 100644 index a293a2c7e6..0000000000 --- a/tensorflow_federated/python/program/memory_release_manager_test.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import unittest - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -import tree - -from tensorflow_federated.python.program import memory_release_manager -from tensorflow_federated.python.program import program_test_utils - - -class MemoryReleaseManagerTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - # materialized values - ('none', None, None), - ('bool', True, True), - ('int', 1, 1), - ('str', 'abc', 'abc'), - ('numpy_int', np.int32(1), np.int32(1)), - ('numpy_array', np.array([1] * 3, np.int32), np.array([1] * 3, np.int32)), - # materializable value references - ( - 'value_reference_tensor', - program_test_utils.TestMaterializableValueReference(1), - 1, - ), - ( - 'value_reference_sequence', - program_test_utils.TestMaterializableValueReference([1, 2, 3]), - [1, 2, 3], - ), - # serializable values - ( - 'serializable_value', - program_test_utils.TestSerializable(1, 2), - program_test_utils.TestSerializable(1, 2), - ), - # other values - ( - 'attrs', - program_test_utils.TestAttrs(1, 2), - program_test_utils.TestAttrs(1, 2), - ), - # structures - ( - 'list', - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [ - True, - 1, - 'abc', - 2, - program_test_utils.TestSerializable(3, 4), - ], - ), - ('list_empty', [], []), - ( - 'list_nested', - [ - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - [ - [ - True, - 1, - 'abc', - 2, - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - ), - ( - 'dict_ordered', - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ( - 'dict_unordered', - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ('dict_empty', {}, {}), - ( - 'dict_nested', - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - ), - ( - 'named_tuple', - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=2, - e=program_test_utils.TestSerializable(3, 4), - ), - ), - ( - 'named_tuple_nested', - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=2, - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - ), - ) - async def test_release_saves_value(self, value, expected_value): - release_mngr = memory_release_manager.MemoryReleaseManager() - key = 1 - - await release_mngr.release(value, key=key) - - self.assertLen(release_mngr._values, 1) - actual_value = release_mngr._values[1] - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - - async def test_remove_all_with_no_values(self): - release_mngr = memory_release_manager.MemoryReleaseManager() - - release_mngr.remove_all() - - self.assertEqual(release_mngr._values, {}) - - async def test_remove_all_with_values(self): - release_mngr = memory_release_manager.MemoryReleaseManager() - release_mngr._values = collections.OrderedDict([(i, i) for i in range(3)]) - - release_mngr.remove_all() - - self.assertEqual(release_mngr._values, {}) - - @parameterized.named_parameters( - ('0', 0), - ('1', 1), - ('10', 10), - ) - def test_values_returns_values(self, count): - expected_values = collections.OrderedDict([(i, i) for i in range(count)]) - release_mngr = memory_release_manager.MemoryReleaseManager() - release_mngr._values = expected_values - - actual_values = release_mngr.values() - - self.assertEqual(actual_values, expected_values) - - def test_values_returns_copy(self): - release_mngr = memory_release_manager.MemoryReleaseManager() - - values_1 = release_mngr.values() - values_2 = release_mngr.values() - self.assertIsNot(values_1, values_2) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/program/native_platform.py b/tensorflow_federated/python/program/native_platform.py index c9afb2cac7..1dc6195d3a 100644 --- a/tensorflow_federated/python/program/native_platform.py +++ b/tensorflow_federated/python/program/native_platform.py @@ -17,26 +17,22 @@ from collections.abc import Mapping from typing import Optional, Union +import federated_language import tree from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_conversions -from tensorflow_federated.python.program import federated_context from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference -class NativeValueReference(value_reference.MaterializableValueReference): +class NativeValueReference( + federated_language.program.MaterializableValueReference +): """A `tff.program.MaterializableValueReference` backed by a task.""" def __init__( self, task: asyncio.Task, - type_signature: value_reference.MaterializableTypeSignature, + type_signature: federated_language.program.MaterializableTypeSignature, ): """Returns an initialized `tff.program.NativeValueReference`. @@ -48,11 +44,13 @@ def __init__( self._type_signature = type_signature @property - def type_signature(self) -> value_reference.MaterializableTypeSignature: + def type_signature( + self, + ) -> federated_language.program.MaterializableTypeSignature: """The `tff.TensorType` of this object.""" return self._type_signature - async def get_value(self) -> value_reference.MaterializedValue: + async def get_value(self) -> federated_language.program.MaterializedValue: """Returns the referenced value as a numpy scalar or array.""" return await self._task @@ -72,7 +70,7 @@ def __eq__(self, other: object) -> bool: def _create_structure_of_references( task: asyncio.Task, - type_signature: computation_types.Type, + type_signature: federated_language.Type, ) -> structure_utils.Structure[NativeValueReference]: """Returns a structure of `tff.program.NativeValueReference`s. @@ -85,10 +83,10 @@ def _create_structure_of_references( Raises: NotImplementedError: If `type_signature` contains an unexpected type. """ - if isinstance(type_signature, computation_types.StructType): + if isinstance(type_signature, federated_language.StructType): def _get_container_cls( - type_spec: computation_types.StructType, + type_spec: federated_language.StructType, ) -> type[object]: container_cls = type_spec.python_container if container_cls is None: @@ -106,7 +104,7 @@ def _get_container_cls( async def _get_item( task: asyncio.Task, key: Union[str, int] - ) -> value_reference.MaterializedValue: + ) -> federated_language.program.MaterializedValue: value = await task return value[key] @@ -121,24 +119,28 @@ async def _get_item( element_task = asyncio.create_task(element) element = _create_structure_of_references(element_task, element_type) elements.append(element) - return type_conversions.to_structure_with_type(elements, type_signature) + return federated_language.framework.to_structure_with_type( + elements, type_signature + ) elif ( - isinstance(type_signature, computation_types.FederatedType) - and type_signature.placement == placements.SERVER + isinstance(type_signature, federated_language.FederatedType) + and type_signature.placement == federated_language.SERVER ): return _create_structure_of_references(task, type_signature.member) - elif isinstance(type_signature, computation_types.SequenceType): + elif isinstance(type_signature, federated_language.SequenceType): return NativeValueReference(task, type_signature) - elif isinstance(type_signature, computation_types.TensorType): + elif isinstance(type_signature, federated_language.TensorType): return NativeValueReference(task, type_signature) else: raise NotImplementedError(f'Unexpected type found: {type_signature}.') -class NativeFederatedContext(federated_context.FederatedContext): +class NativeFederatedContext(federated_language.program.FederatedContext): """A `tff.program.FederatedContext` backed by an execution context.""" - def __init__(self, context: async_execution_context.AsyncExecutionContext): + def __init__( + self, context: federated_language.framework.AsyncExecutionContext + ): """Returns an initialized `tff.program.NativeFederatedContext`. Args: @@ -148,8 +150,8 @@ def __init__(self, context: async_execution_context.AsyncExecutionContext): def invoke( self, - comp: computation_base.Computation, - arg: Optional[federated_context.ComputationArg], + comp: federated_language.framework.Computation, + arg: Optional[federated_language.program.ComputationArg], ) -> structure_utils.Structure[NativeValueReference]: """Invokes the `comp` with the argument `arg`. @@ -172,17 +174,19 @@ def invoke( server-placed values, or tensors. """ result_type = comp.type_signature.result - if not federated_context.contains_only_server_placed_data(result_type): + if not federated_language.program.contains_only_server_placed_data( + result_type + ): raise ValueError( 'Expected the result type of `comp` to contain only structures, ' f'server-placed values, or tensors, found {result_type}.' ) async def _invoke( - context: async_execution_context.AsyncExecutionContext, - comp: computation_base.Computation, - arg: value_reference.MaterializableStructure, - ) -> value_reference.MaterializedStructure: + context: federated_language.framework.AsyncExecutionContext, + comp: federated_language.framework.Computation, + arg: federated_language.program.MaterializableStructure, + ) -> federated_language.program.MaterializedStructure: if comp.type_signature.parameter is not None: def _to_python(obj): @@ -192,7 +196,7 @@ def _to_python(obj): return None arg = tree.traverse(_to_python, arg) - arg = await value_reference.materialize_value(arg) + arg = await federated_language.program.materialize_value(arg) return await context.invoke(comp, arg) diff --git a/tensorflow_federated/python/program/native_platform_test.py b/tensorflow_federated/python/program/native_platform_test.py index b306c38409..2afaf21f26 100644 --- a/tensorflow_federated/python/program/native_platform_test.py +++ b/tensorflow_federated/python/program/native_platform_test.py @@ -18,20 +18,15 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tree from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.program import native_platform from tensorflow_federated.python.program import program_test_utils from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference def _create_task(value: object) -> object: @@ -44,9 +39,9 @@ async def _fn(value: object) -> object: def _create_identity_federated_computation( - type_signature: computation_types.Type, -) -> computation_base.Computation: - @federated_computation.federated_computation(type_signature) + type_signature: federated_language.Type, +) -> federated_language.framework.Computation: + @federated_language.federated_computation(type_signature) def _identity(value: object) -> object: return value @@ -54,8 +49,8 @@ def _identity(value: object) -> object: def _create_identity_tensorflow_computation( - type_signature: computation_types.Type, -) -> computation_base.Computation: + type_signature: federated_language.Type, +) -> federated_language.framework.Computation: @tensorflow_computation.tf_computation(type_signature) def _identity(value: object) -> object: return value @@ -71,25 +66,25 @@ class NativeValueReferenceTest( ( 'tensor_bool', lambda: _create_task(True), - computation_types.TensorType(np.bool_), + federated_language.TensorType(np.bool_), True, ), ( 'tensor_int', lambda: _create_task(1), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), 1, ), ( 'tensor_str', lambda: _create_task('abc'), - computation_types.TensorType(np.str_), + federated_language.TensorType(np.str_), 'abc', ), ( 'sequence', lambda: _create_task([1, 2, 3]), - computation_types.SequenceType(np.int32), + federated_language.SequenceType(np.int32), [1, 2, 3], ), ) @@ -115,50 +110,50 @@ class CreateStructureOfReferencesTest( ( 'tensor', lambda: _create_task(1), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), lambda: native_platform.NativeValueReference( - _create_task(1), computation_types.TensorType(np.int32) + _create_task(1), federated_language.TensorType(np.int32) ), ), ( 'sequence', lambda: _create_task([1, 2, 3]), - computation_types.SequenceType(np.int32), + federated_language.SequenceType(np.int32), lambda: native_platform.NativeValueReference( _create_task([1, 2, 3]), - computation_types.SequenceType(np.int32), + federated_language.SequenceType(np.int32), ), ), ( 'federated_server', lambda: _create_task(1), - computation_types.FederatedType(np.int32, placements.SERVER), + federated_language.FederatedType(np.int32, federated_language.SERVER), lambda: native_platform.NativeValueReference( - _create_task(1), computation_types.TensorType(np.int32) + _create_task(1), federated_language.TensorType(np.int32) ), ), ( 'struct_unnamed', lambda: _create_task([True, 1, 'abc']), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [np.bool_, np.int32, np.str_], list ), lambda: [ native_platform.NativeValueReference( - _create_task(True), computation_types.TensorType(np.bool_) + _create_task(True), federated_language.TensorType(np.bool_) ), native_platform.NativeValueReference( - _create_task(1), computation_types.TensorType(np.int32) + _create_task(1), federated_language.TensorType(np.int32) ), native_platform.NativeValueReference( - _create_task('abc'), computation_types.TensorType(np.str_) + _create_task('abc'), federated_language.TensorType(np.str_) ), ], ), ( 'struct_named', lambda: _create_task({'a': True, 'b': 1, 'c': 'abc'}), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -168,24 +163,24 @@ class CreateStructureOfReferencesTest( ), lambda: { 'a': native_platform.NativeValueReference( - _create_task(True), computation_types.TensorType(np.bool_) + _create_task(True), federated_language.TensorType(np.bool_) ), 'b': native_platform.NativeValueReference( - _create_task(1), computation_types.TensorType(np.int32) + _create_task(1), federated_language.TensorType(np.int32) ), 'c': native_platform.NativeValueReference( - _create_task('abc'), computation_types.TensorType(np.str_) + _create_task('abc'), federated_language.TensorType(np.str_) ), }, ), ( 'struct_nested', lambda: _create_task({'x': {'a': True, 'b': 1}, 'y': {'c': 'abc'}}), - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -195,7 +190,7 @@ class CreateStructureOfReferencesTest( ), ( 'y', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('c', np.str_), ], @@ -209,17 +204,17 @@ class CreateStructureOfReferencesTest( 'x': { 'a': native_platform.NativeValueReference( _create_task(True), - computation_types.TensorType(np.bool_), + federated_language.TensorType(np.bool_), ), 'b': native_platform.NativeValueReference( _create_task(1), - computation_types.TensorType(np.int32), + federated_language.TensorType(np.int32), ), }, 'y': { 'c': native_platform.NativeValueReference( _create_task('abc'), - computation_types.TensorType(np.str_), + federated_language.TensorType(np.str_), ), }, }, @@ -234,8 +229,12 @@ async def test_returns_value( ) expected_value = expected_value_factory() - actual_value = await value_reference.materialize_value(actual_value) - expected_value = await value_reference.materialize_value(expected_value) + actual_value = await federated_language.program.materialize_value( + actual_value + ) + expected_value = await federated_language.program.materialize_value( + expected_value + ) tree.assert_same_structure(actual_value, expected_value) program_test_utils.assert_same_key_order(actual_value, expected_value) actual_value = program_test_utils.to_python(actual_value) @@ -245,10 +244,12 @@ async def test_returns_value( @parameterized.named_parameters( ( 'federated_clients', - computation_types.FederatedType(np.int32, placements.CLIENTS), + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ), ), - ('function', computation_types.FunctionType(np.int32, np.int32)), - ('placement', computation_types.PlacementType()), + ('function', federated_language.FunctionType(np.int32, np.int32)), + ('placement', federated_language.PlacementType()), ) async def test_raises_not_implemented_error_with_type_signature( self, type_signature @@ -267,7 +268,7 @@ class NativeFederatedContextTest( ( 'tensor', _create_identity_federated_computation( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ), 1, 1, @@ -275,7 +276,7 @@ class NativeFederatedContextTest( ( 'sequence', _create_identity_tensorflow_computation( - computation_types.SequenceType(np.int32) + federated_language.SequenceType(np.int32) ), [1, 2, 3], [1, 2, 3], @@ -283,7 +284,9 @@ class NativeFederatedContextTest( ( 'federated_server', _create_identity_federated_computation( - computation_types.FederatedType(np.int32, placements.SERVER) + federated_language.FederatedType( + np.int32, federated_language.SERVER + ) ), 1, 1, @@ -291,7 +294,7 @@ class NativeFederatedContextTest( ( 'struct_unnamed', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [np.bool_, np.int32, np.str_], list ) ), @@ -301,7 +304,7 @@ class NativeFederatedContextTest( ( 'struct_named_ordered', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -316,7 +319,7 @@ class NativeFederatedContextTest( ( 'struct_named_unordered', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('c', np.str_), ('b', np.int32), @@ -335,7 +338,7 @@ async def test_invoke_returns_result(self, comp, arg, expected_value): with program_test_utils.assert_not_warns(RuntimeWarning): result = context.invoke(comp, arg) - actual_value = await value_reference.materialize_value(result) + actual_value = await federated_language.program.materialize_value(result) tree.assert_same_structure(actual_value, expected_value) program_test_utils.assert_same_key_order(actual_value, expected_value) @@ -347,11 +350,11 @@ async def test_invoke_returns_result(self, comp, arg, expected_value): ( 'struct_nested', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -361,7 +364,7 @@ async def test_invoke_returns_result(self, comp, arg, expected_value): ), ( 'y', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('c', np.str_), ], @@ -378,11 +381,11 @@ async def test_invoke_returns_result(self, comp, arg, expected_value): ( 'struct_partially_empty', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -390,7 +393,7 @@ async def test_invoke_returns_result(self, comp, arg, expected_value): dict, ), ), - ('y', computation_types.StructWithPythonType([], dict)), + ('y', federated_language.StructWithPythonType([], dict)), ], dict, ) @@ -404,7 +407,7 @@ async def test_invoke_returns_result_materialized_sequentially( ): context = execution_contexts.create_async_local_cpp_execution_context() mock_context = mock.Mock( - spec=async_execution_context.AsyncExecutionContext, wraps=context + spec=federated_language.framework.AsyncExecutionContext, wraps=context ) context = native_platform.NativeFederatedContext(mock_context) @@ -421,11 +424,11 @@ async def test_invoke_returns_result_materialized_sequentially( ( 'struct_nested', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -435,7 +438,7 @@ async def test_invoke_returns_result_materialized_sequentially( ), ( 'y', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('c', np.str_), ], @@ -452,11 +455,11 @@ async def test_invoke_returns_result_materialized_sequentially( ( 'struct_partially_empty', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -464,7 +467,7 @@ async def test_invoke_returns_result_materialized_sequentially( dict, ), ), - ('y', computation_types.StructWithPythonType([], dict)), + ('y', federated_language.StructWithPythonType([], dict)), ], dict, ) @@ -478,13 +481,13 @@ async def test_invoke_returns_result_materialized_concurrently( ): context = execution_contexts.create_async_local_cpp_execution_context() mock_context = mock.Mock( - spec=async_execution_context.AsyncExecutionContext, wraps=context + spec=federated_language.framework.AsyncExecutionContext, wraps=context ) context = native_platform.NativeFederatedContext(mock_context) with program_test_utils.assert_not_warns(RuntimeWarning): result = context.invoke(comp, arg) - actual_value = await value_reference.materialize_value(result) + actual_value = await federated_language.program.materialize_value(result) self.assertEqual(actual_value, expected_value) mock_context.invoke.assert_called_once() @@ -493,11 +496,11 @@ async def test_invoke_returns_result_materialized_concurrently( ( 'struct_nested', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -507,7 +510,7 @@ async def test_invoke_returns_result_materialized_concurrently( ), ( 'y', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('c', np.str_), ], @@ -524,11 +527,11 @@ async def test_invoke_returns_result_materialized_concurrently( ( 'struct_partially_empty', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ( 'x', - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ ('a', np.bool_), ('b', np.int32), @@ -536,7 +539,7 @@ async def test_invoke_returns_result_materialized_concurrently( dict, ), ), - ('y', computation_types.StructWithPythonType([], dict)), + ('y', federated_language.StructWithPythonType([], dict)), ], dict, ) @@ -550,16 +553,16 @@ async def test_invoke_returns_result_materialized_multiple( ): context = execution_contexts.create_async_local_cpp_execution_context() mock_context = mock.Mock( - spec=async_execution_context.AsyncExecutionContext, wraps=context + spec=federated_language.framework.AsyncExecutionContext, wraps=context ) context = native_platform.NativeFederatedContext(mock_context) with program_test_utils.assert_not_warns(RuntimeWarning): result = context.invoke(comp, arg) actual_value = await asyncio.gather( - value_reference.materialize_value(result), - value_reference.materialize_value(result), - value_reference.materialize_value(result), + federated_language.program.materialize_value(result), + federated_language.program.materialize_value(result), + federated_language.program.materialize_value(result), ) expected_value = [expected_value] * 3 @@ -570,7 +573,7 @@ async def test_invoke_returns_result_materialized_multiple( ( 'struct_unnamed_empty', _create_identity_federated_computation( - computation_types.StructWithPythonType([], list) + federated_language.StructWithPythonType([], list) ), [], [], @@ -578,7 +581,7 @@ async def test_invoke_returns_result_materialized_multiple( ( 'struct_named_empty', _create_identity_federated_computation( - computation_types.StructWithPythonType([], dict) + federated_language.StructWithPythonType([], dict) ), {}, {}, @@ -586,10 +589,10 @@ async def test_invoke_returns_result_materialized_multiple( ( 'struct_nested_empty', _create_identity_federated_computation( - computation_types.StructWithPythonType( + federated_language.StructWithPythonType( [ - ('x', computation_types.StructWithPythonType([], dict)), - ('y', computation_types.StructWithPythonType([], dict)), + ('x', federated_language.StructWithPythonType([], dict)), + ('y', federated_language.StructWithPythonType([], dict)), ], dict, ) @@ -603,13 +606,13 @@ async def test_invoke_returns_result_comp_not_called( ): context = execution_contexts.create_async_local_cpp_execution_context() mock_context = mock.Mock( - spec=async_execution_context.AsyncExecutionContext, wraps=context + spec=federated_language.framework.AsyncExecutionContext, wraps=context ) context = native_platform.NativeFederatedContext(mock_context) with program_test_utils.assert_not_warns(RuntimeWarning): result = context.invoke(comp, arg) - actual_value = await value_reference.materialize_value(result) + actual_value = await federated_language.program.materialize_value(result) self.assertEqual(actual_value, expected_value) mock_context.invoke.assert_not_called() @@ -618,23 +621,25 @@ async def test_invoke_returns_result_comp_not_called( ( 'federated_clients', _create_identity_federated_computation( - computation_types.FederatedType(np.int32, placements.CLIENTS) + federated_language.FederatedType( + np.int32, federated_language.CLIENTS + ) ), 1, ), ( 'function', _create_identity_federated_computation( - computation_types.FunctionType(np.int32, np.int32) + federated_language.FunctionType(np.int32, np.int32) ), _create_identity_federated_computation( - computation_types.TensorType(np.int32) + federated_language.TensorType(np.int32) ), ), ( 'placement', _create_identity_federated_computation( - computation_types.PlacementType() + federated_language.PlacementType() ), None, ), diff --git a/tensorflow_federated/python/program/program_state_manager.py b/tensorflow_federated/python/program/program_state_manager.py deleted file mode 100644 index 93365cf859..0000000000 --- a/tensorflow_federated/python/program/program_state_manager.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for saving and loading program state in a federated program.""" - -import abc -from typing import Generic, Optional, Union, TypeVar - -from tensorflow_federated.python.common_libs import serializable -from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference - -# ProgramStateManager's may save any value in addition to materializable values. -ProgramStateValue = Union[ - object, - value_reference.MaterializableValue, - # `tff.Serializable` objects are intended to be impemented by platform - # authors and used by program logic authors; it should not be common for - # program logic authors to implement `tff.Serializable` objects. - serializable.Serializable, -] -ProgramStateStructure = TypeVar( - 'ProgramStateStructure', - bound=structure_utils.Structure[ProgramStateValue], -) - - -class ProgramStateExistsError(Exception): - """Raised when the program state already exists.""" - - def __init__(self, *, version: int, path: str): - super().__init__( - f'Program state already exists for version [{version}] at path' - f' [{path}].' - ) - - -class ProgramStateNotFoundError(Exception): - """Raised when the program state cannot be found.""" - - def __init__(self, version: int): - super().__init__(f'No program state found for version: {version}.') - - -class ProgramStateManager(abc.ABC, Generic[ProgramStateStructure]): - """An interface for saving and loading program state in a federated program. - - A `tff.program.ProgramStateManager` is used to implement fault tolerance in a - federated program. The structure or type of the program state that is saved is - unknown at construction time and can change as the program runs. - """ - - @abc.abstractmethod - async def get_versions(self) -> Optional[list[int]]: - """Returns a list of saved versions or `None`. - - Returns: - A list of saved versions or `None` if there is no saved program state. - """ - raise NotImplementedError - - @abc.abstractmethod - async def load( - self, version: int, structure: ProgramStateStructure - ) -> ProgramStateStructure: - """Returns the saved program state for the given `version`. - - Args: - version: A integer representing the version of a saved program state. - structure: The structure of the saved program state for the given - `version` used to support serialization and deserailization of - user-defined classes in the structure. - - Raises: - ProgramStateNotFoundError: If there is no program state for the given - `version`. - """ - raise NotImplementedError - - async def load_latest( - self, structure: ProgramStateStructure - ) -> tuple[ProgramStateStructure, int]: - """Returns the latest saved program state and version or (`None`, 0). - - Args: - structure: The structure of the saved program state for the given - `version` used to support serialization and deserailization of - user-defined classes in the structure. - - Returns: - A tuple of the latest saved (program state, version) or (`None`, 0) if - there is no latest saved program state. - """ - versions = await self.get_versions() - if versions is None: - return None, 0 - latest_version = max(versions) - try: - return await self.load(latest_version, structure), latest_version - except ProgramStateNotFoundError: - return None, 0 - - @abc.abstractmethod - async def save( - self, program_state: ProgramStateStructure, version: int - ) -> None: - """Saves `program_state` for the given `version`. - - Args: - program_state: A `tff.program.ProgramStateStructure` to save. - version: A strictly increasing integer representing the version of a saved - `program_state`. - - Raises: - ProgramStateExistsError: If there is already program state for the given - `version`. - """ - raise NotImplementedError diff --git a/tensorflow_federated/python/program/program_state_manager_test.py b/tensorflow_federated/python/program/program_state_manager_test.py deleted file mode 100644 index fe9aa6fa38..0000000000 --- a/tensorflow_federated/python/program/program_state_manager_test.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional -import unittest -from unittest import mock - -from absl.testing import absltest - -from tensorflow_federated.python.program import program_state_manager - - -class _TestProgramStateManager( - program_state_manager.ProgramStateManager[ - program_state_manager.ProgramStateStructure - ] -): - """A test implementation of `tff.program.ProgramStateManager`. - - A `tff.program.ProgramStateManager` cannot be constructed directly because it - has abstract methods, this implementation exists to make it possible to - construct instances of `tff.program.ProgramStateManager` that can used as - stubs or mocked. - """ - - async def get_versions(self) -> Optional[list[int]]: - raise NotImplementedError - - async def load( - self, version: int, structure: program_state_manager.ProgramStateStructure - ) -> program_state_manager.ProgramStateStructure: - del version, structure # Unused. - raise NotImplementedError - - async def save( - self, - program_state: program_state_manager.ProgramStateStructure, - version: int, - ) -> None: - del program_state, version # Unused. - raise NotImplementedError - - -class ProgramStateManagerTest( - absltest.TestCase, unittest.IsolatedAsyncioTestCase -): - - async def test_load_latest_with_saved_program_state(self): - program_state_mngr = _TestProgramStateManager() - program_state_mngr.get_versions = mock.AsyncMock(return_value=[1, 2, 3]) - program_state_mngr.load = mock.AsyncMock(return_value='test3') - structure = 'test' - - (program_state, version) = await program_state_mngr.load_latest(structure) - - program_state_mngr.get_versions.assert_called_once_with() - program_state_mngr.load.assert_called_once_with(3, structure) - self.assertEqual(program_state, 'test3') - self.assertEqual(version, 3) - - async def test_load_latest_with_no_saved_program_state(self): - program_state_mngr = _TestProgramStateManager() - program_state_mngr.get_versions = mock.AsyncMock(return_value=None) - program_state_mngr.load = mock.AsyncMock() - structure = 'test' - - (program_state, version) = await program_state_mngr.load_latest(structure) - - program_state_mngr.get_versions.assert_called_once_with() - program_state_mngr.load.assert_not_called() - self.assertIsNone(program_state) - self.assertEqual(version, 0) - - async def test_load_latest_with_load_failure(self): - program_state_mngr = _TestProgramStateManager() - program_state_mngr.get_versions = mock.AsyncMock(return_value=[1, 2, 3]) - program_state_mngr.load = mock.AsyncMock( - side_effect=program_state_manager.ProgramStateNotFoundError(version=0) - ) - structure = 'test' - - (program_state, version) = await program_state_mngr.load_latest(structure) - - program_state_mngr.get_versions.assert_called_once_with() - program_state_mngr.load.assert_called_once_with(3, structure) - self.assertIsNone(program_state) - self.assertEqual(version, 0) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/program/program_test_utils.py b/tensorflow_federated/python/program/program_test_utils.py index 1f597f1c11..b0441043d5 100644 --- a/tensorflow_federated/python/program/program_test_utils.py +++ b/tensorflow_federated/python/program/program_test_utils.py @@ -22,43 +22,43 @@ import warnings import attrs +import federated_language import numpy as np import tensorflow as tf import tree from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.common_libs import serializable -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.program import value_reference T = TypeVar('T') class TestMaterializableValueReference( - value_reference.MaterializableValueReference + federated_language.program.MaterializableValueReference ): """A test implementation of `tff.program.MaterializableValueReference`.""" - def __init__(self, value: value_reference.MaterializedValue): + def __init__(self, value: federated_language.program.MaterializedValue): self._value = value if isinstance(value, bool): - self._type_signature = computation_types.TensorType(np.bool_) + self._type_signature = federated_language.TensorType(np.bool_) elif isinstance(value, int): - self._type_signature = computation_types.TensorType(np.int32) + self._type_signature = federated_language.TensorType(np.int32) elif isinstance(value, str): - self._type_signature = computation_types.TensorType(np.str_) + self._type_signature = federated_language.TensorType(np.str_) elif isinstance(value, list): - self._type_signature = computation_types.SequenceType(np.int32) + self._type_signature = federated_language.SequenceType(np.int32) else: raise NotImplementedError(f'Unexpected type found: {type(value)}.') @property - def type_signature(self) -> value_reference.MaterializableTypeSignature: + def type_signature( + self, + ) -> federated_language.program.MaterializableTypeSignature: return self._type_signature - async def get_value(self) -> value_reference.MaterializedValue: + async def get_value(self) -> federated_language.program.MaterializedValue: return self._value def __eq__(self, other: object) -> bool: @@ -68,13 +68,13 @@ def __eq__(self, other: object) -> bool: return NotImplemented if self._type_signature != other._type_signature: return False - if isinstance(self._type_signature, computation_types.SequenceType): + if isinstance(self._type_signature, federated_language.SequenceType): return list(self._value) == list(other._value) else: return self._value == other._value -class TestSerializable(serializable.Serializable): +class TestSerializable(federated_language.Serializable): """A test implementation of `tff.Serializable`.""" def __init__(self, a: int, b: int) -> None: @@ -110,7 +110,7 @@ class TestNamedTuple1(NamedTuple): a: bool b: int c: str - d: value_reference.MaterializableValueReference + d: federated_language.program.MaterializableValueReference e: Optional[TestSerializable] diff --git a/tensorflow_federated/python/program/release_manager.py b/tensorflow_federated/python/program/release_manager.py deleted file mode 100644 index f7acd12a98..0000000000 --- a/tensorflow_federated/python/program/release_manager.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2019, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for releasing values from a federated program.""" - -import abc -import asyncio -from collections.abc import Callable, Mapping, Sequence -import datetime -import typing -from typing import Generic, Optional, TypeVar, Union - -import attrs -import tree - -from tensorflow_federated.python.common_libs import py_typecheck -from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference - -# ReleaseManager's may release any value in addition to materializable values. -ReleasableValue = Union[ - object, - value_reference.MaterializableValue, -] -ReleasableStructure = TypeVar( - 'ReleasableStructure', - bound=structure_utils.Structure[ReleasableValue], -) -Key = TypeVar('Key') - - -class ReleasedValueNotFoundError(Exception): - """Raised when a released value cannot be found.""" - - def __init__(self, key: object): - super().__init__(f'No released value found for key: {key}.') - - -class ReleaseManager(abc.ABC, Generic[ReleasableStructure, Key]): - """An interface for releasing values from a federated program. - - A `tff.program.ReleaseManager` is used to release values from platform storage - to customer storage in a federated program. - """ - - @abc.abstractmethod - async def release(self, value: ReleasableStructure, key: Key) -> None: - """Releases `value` from a federated program. - - An implementation of this interface should be specific about the types of - `value` and `key` for this method and should document how the `key` will be - used. This allows a federated program to understand how to create a `key` - for the `value` before it is released. For example, a - `tff.program.ReleaseManager` that releases metrics keyed by a strictly - increasing integer might specify a `value` type of - `Mapping[str, ReleasableValue]` and a `key` type of `int`. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: A value used to reference the released `value`. - """ - raise NotImplementedError - - -class NotFilterableError(Exception): - """Raised when the structure cannot be filtered.""" - - -# Sentinel object used by the `tff.program.FilteringReleaseManager` to indicate -# that a subtree can be filtered when traversing structures of values and type -# signatures. -_FILTERED_SUBTREE = object() - - -class FilteringReleaseManager(ReleaseManager[ReleasableStructure, Key]): - """A `tff.program.ReleaseManager` that filters values before releasing them. - - A `tff.program.FilteringReleaseManager` is a utility for filtering values - before releasing the values and is used to release values from platform - storage to customer storage in a federated program. - - Values are filtered using a `filter_fn` and released to the `release_manager`. - - The `filter_fn` is a `Callable` that has a single parameter `path` and returns - a `bool`, and is used to filter values before they are released. A `path` is a - tuple of indices and/or keys which uniquely identifies the position of the - corresponding item in the `value`; `path` matches the expectations of the - `tree` library. - - The `filter_fn` is applied to the items in the structure but not the structure - itself. If all the items in a structure are filtered out, then the structure - will be filtered out as well. - - For example: - - ``` - filtering_manager = tff.program.FilteringReleaseManager( - release_manager=..., - filter_fn=..., - ) - - value = { - 'loss': 1.0, - 'accuracy': 0.5, - } - await filtering_manager.release(value, ...) - ``` - - If `filter_fn` is: - - * `lambda _: True` then the entire structure is released. - * `lambda _: False` then nothing is released. - * `lambda path: path == ('loss',)` then `{'loss': 1.0}` is released. - - Note: The path `()` corresponds to the root of the structure; because the - `filter_fn` is applied to the items in the structure but not the structure - itself, this path can be used to filter individual values from structures of - values. - - Important: Most `tff.program.ReleasableStructure` can be filtered, including - individual values, structures, and structures nested in `NamedTuple`s. - However, the fields of a `NamedTuple` cannot be filtered. - """ - - def __init__( - self, - release_manager: ReleaseManager[ReleasableStructure, Key], - filter_fn: Callable[[tuple[Union[str, int], ...]], bool], - ): - """Returns an initialized `tff.program.FilteringReleaseManager`. - - Args: - release_manager: A `tff.program.ReleaseManager` used to release values to. - filter_fn: A `Callable` used to filter values before they are released, - this function has a single parameter `path` and returns a `bool`. - """ - self._release_manager = release_manager - self._filter_fn = filter_fn - - async def release(self, value: ReleasableStructure, key: Key) -> None: - """Releases `value` from a federated program. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: A value used to reference the released `value`. - - Raises: - NotFilterableError: If the `value` cannot be filtered. - """ - - def _filter_value( - path: tuple[Union[str, int], ...], - subtree: ReleasableStructure, - ) -> Optional[Union[ReleasableStructure, type(_FILTERED_SUBTREE)]]: - """The function to apply when filtering the `value`. - - This function is meant to be used with `tree.traverse_with_path` to filter - the `value`. The function `tree.traverse_with_path` is used because the - traversal functions from `tree` apply a function to the structure and - the leaves, whereas the map functions only apply a function to the leaves. - Additionally, `path` is used to determine which parts of the structure to - filter. - - See https://tree.readthedocs.io/en/latest/api.html#tree.traverse for more - information. - - Args: - path: A tuple of indices and/or keys which uniquely identifies the - position of `subtree` in the `value`. - subtree: A substructure in `value`. - - Returns: - A filtered value or `_FILTERED_SUBTREE` if the entire structure was - filtered. - - Raises: - NotFilterableError: If `subtree` cannot be filtered. - """ - if tree.is_nested(subtree) and not attrs.has(type(subtree)): - # TODO: b/224484886 - Downcasting to all handled types. - subtree = typing.cast( - Union[Sequence[object], Mapping[str, object]], subtree - ) - if isinstance(subtree, Sequence): - elements = [x for x in subtree if x is not _FILTERED_SUBTREE] - if not elements: - return _FILTERED_SUBTREE - elif isinstance(subtree, py_typecheck.SupportsNamedTuple): - if len(subtree) != len(elements): - fields = list(type(subtree)._fields) - missing_fields = [ - k - for k, v in subtree._asdict().items() - if v is _FILTERED_SUBTREE - ] - raise NotFilterableError( - 'The fields of a `NamedTuple` cannot be filtered. Expected ' - f'{type(subtree)} to have fields {fields}, found it was ' - f'missing fields {missing_fields}.' - ) - - return type(subtree)(*elements) - else: - # Assumes the `Sequence` has a constructor that accepts `elements`, - # this is safe because `tree` makes the same assumption. - return type(subtree)(elements) # pytype: disable=wrong-arg-count - elif isinstance(subtree, Mapping): - items = [ - (k, v) for k, v in subtree.items() if v is not _FILTERED_SUBTREE - ] - if not items: - return _FILTERED_SUBTREE - else: - # Assumes the `Mapping` has a constructor that accepts `items`, - # this is safe because `tree` makes the same assumption. - return type(subtree)(items) # pytype: disable=wrong-arg-count - else: - raise NotImplementedError(f'Unexpected type found: {type(subtree)}.') - else: - if self._filter_fn(path): - return None - else: - return _FILTERED_SUBTREE - - filtered_value = tree.traverse_with_path( - _filter_value, value, top_down=False - ) - - if filtered_value is not _FILTERED_SUBTREE: - await self._release_manager.release(filtered_value, key=key) - - -class GroupingReleaseManager(ReleaseManager[ReleasableStructure, Key]): - """A `tff.program.ReleaseManager` that releases values to other release managers. - - A `tff.program.GroupingReleaseManager` is a utility for release values from a - federated program to a collection of other release managers and is used to - release values from platform storage to customer storage in a federated - program. - - Values are released using each of the `tff.program.ReleaseManager`s in the - given `release_managers`. - """ - - def __init__( - self, release_managers: Sequence[ReleaseManager[ReleasableStructure, Key]] - ): - """Returns an initialized `tff.program.GroupingReleaseManager`. - - Args: - release_managers: A sequence of `tff.program.ReleaseManager` used to - release values to. - - Raises: - ValueError: If `release_managers` is empty. - """ - if not release_managers: - raise ValueError('Expected `release_managers` to not be empty.') - - self._release_managers = release_managers - - async def release(self, value: ReleasableStructure, key: Key) -> None: - """Releases `value` from a federated program. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: A value used to reference the released `value`. - """ - await asyncio.gather( - *[m.release(value, key=key) for m in self._release_managers] - ) - - -class PeriodicReleaseManager(ReleaseManager[ReleasableStructure, Key]): - """A `tff.program.ReleaseManager` that releases values at regular intervals. - - A `tff.program.PeriodicReleaseManager` is a utility for releasing values at - regular intervals and is used to release values from platform storage to - customer storage in a federated program. - - The interval can be controlled at construction time by setting the - `periodicity`. The `periodicity` can be a positive integer or - `datetime.timedelta`. A `periodicity` of `3` means that every third value is - released to the `release_manager`, and invoking `release` ten times will - release the third, sixth, and ninth values. A `periodicity` of - `datetime.timedelta(hours=3)` means that three hours after the previously - released value the next value is released to the `release_manager`. - - Note: that a `periodicity` of one or a very small `datetime.timedelta` will - release every value, making the `tff.program.PeriodicReleaseManager` a noop - wrapper around the `release_manager`. - """ - - def __init__( - self, - release_manager: ReleaseManager[ReleasableStructure, Key], - periodicity: Union[int, datetime.timedelta], - ): - """Returns an initialized `tff.program.PeriodicReleaseManager`. - - Args: - release_manager: A `tff.program.ReleaseManager` used to release values to. - periodicity: The interval to release values. Must be a positive integer or - `datetime.timedelta`. - - Raises: - ValueError: If `periodicity` is not a positive integer or - `datetime.timedelta`. - """ - if (isinstance(periodicity, int) and periodicity < 1) or ( - isinstance(periodicity, datetime.timedelta) - and periodicity.total_seconds() < 1.0 - ): - raise ValueError( - 'Expected `periodicity` to be a positive integer or' - f' `datetime.timedelta`, found {periodicity}.' - ) - - self._release_manager = release_manager - self._periodicity = periodicity - if isinstance(periodicity, int): - self._count = 0 - elif isinstance(periodicity, datetime.timedelta): - self._timestamp = datetime.datetime.now() - else: - raise NotImplementedError( - f'Unexpected `periodicity` found: {type(periodicity)}.' - ) - - async def release(self, value: ReleasableStructure, key: Key) -> None: - """Releases `value` from a federated program. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: A value used to reference the released `value`. - """ - if isinstance(self._periodicity, int): - self._count += 1 - if self._count % self._periodicity == 0: - await self._release_manager.release(value, key=key) - elif isinstance(self._periodicity, datetime.timedelta): - now = datetime.datetime.now() - if now >= self._timestamp + self._periodicity: - self._timestamp = now - await self._release_manager.release(value, key=key) - else: - raise NotImplementedError( - f'Unexpected `periodicity` found: {type(self._periodicity)}.' - ) - - -class DelayedReleaseManager(ReleaseManager[ReleasableStructure, Key]): - """A `tff.program.ReleaseManager` that releases values after specified delay. - - A `tff.program.DelayedReleaseManager` is a utility for releasing values in a - federated program, where releases only take place after a specified delay - count. I.e., releases from platform storage to customer storage will take - place only after a certain number of instances that the `.release()` method is - called. After this delay, further calls to `.release()` will release values - (in accordance with the `release_manager` that was provided). - - For example, in a federated program that runs for a long time, one may want to - skip releasing values until the program has run for a sufficiently long-enough - period. - - The delay count is specified at construction time by setting the `delay` - argument (an integer). A `delay` of `3` means that all values will start to be - released once `release` has been invoked at least three times. - - Note: that a `delay` of one will release every value, making the - `tff.program.DelayedReleaseManager` a noop wrapper around the - `release_manager`. - """ - - def __init__( - self, - release_manager: ReleaseManager[ReleasableStructure, Key], - delay: int, - ): - """Returns an initialized `tff.program.DelayedReleaseManager`. - - Args: - release_manager: A `tff.program.ReleaseManager` used to release values to. - delay: The delay duration before releasing values. Must be a positive - integer. - - Raises: - ValueError: If `delay` is not positive. - """ - if delay < 1: - raise ValueError(f'The `delay` must be positive but found {delay}.') - - self._release_manager = release_manager - self._count = 0 - self._delay = delay - - async def release(self, value: ReleasableStructure, key: Key) -> None: - """Releases `value` from a federated program. - - Args: - value: A `tff.program.ReleasableStructure` to release. - key: A value used to reference the released `value`. - """ - self._count += 1 - if self._count >= self._delay: - await self._release_manager.release(value, key=key) diff --git a/tensorflow_federated/python/program/release_manager_test.py b/tensorflow_federated/python/program/release_manager_test.py deleted file mode 100644 index 21a6c91696..0000000000 --- a/tensorflow_federated/python/program/release_manager_test.py +++ /dev/null @@ -1,580 +0,0 @@ -# Copyright 2022, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -import unittest -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -import tree - -from tensorflow_federated.python.program import program_test_utils -from tensorflow_federated.python.program import release_manager - - -class FilteringReleaseManagerTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - def test_init_does_not_raise_type_error(self): - mock_release_mngr = mock.AsyncMock(spec=release_manager.ReleaseManager) - filter_fn = lambda _: True - - try: - release_manager.FilteringReleaseManager(mock_release_mngr, filter_fn) - except TypeError: - self.fail('Raised `TypeError` unexpectedly.') - - @parameterized.named_parameters( - # materialized values - ('none_filter_none', None, lambda _: True, None), - ('bool_filter_none', True, lambda _: True, True), - ('int_filter_none', 1, lambda _: True, 1), - ('str_filter_none', 'abc', lambda _: True, 'abc'), - ('numpy_int_filter_none', np.int32(1), lambda _: True, np.int32(1)), - ( - 'numpy_array_filter_none', - np.array([1] * 1, np.int32), - lambda _: True, - np.array([1] * 1, np.int32), - ), - # materializable value references - ( - 'value_reference_tensor_filter_none', - program_test_utils.TestMaterializableValueReference(1), - lambda _: True, - program_test_utils.TestMaterializableValueReference(1), - ), - # serializable values - ( - 'serializable_value_filter_none', - program_test_utils.TestSerializable(1, 2), - lambda _: True, - program_test_utils.TestSerializable(1, 2), - ), - # other values - ( - 'attrs_filter_none', - program_test_utils.TestAttrs(1, 2), - lambda _: True, - program_test_utils.TestAttrs(1, 2), - ), - # structures - ( - 'list_filter_none', - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - lambda _: True, - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - ), - ( - 'list_filter_some', - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - lambda path: path == (1,) or path == (2,), - [1, 'abc'], - ), - ( - 'dict_ordered_filter_none', - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - lambda _: True, - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ( - 'dict_unordered_filter_none', - { - 'c': 'abc', - 'b': 1, - 'a': True, - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - lambda _: True, - { - 'c': 'abc', - 'b': 1, - 'a': True, - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ( - 'dict_ordered_filter_some', - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - lambda path: path == ('b',) or path == ('c',), - {'b': 1, 'c': 'abc'}, - ), - ( - 'dict_unordered_filter_some', - { - 'c': 'abc', - 'b': 1, - 'a': True, - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - lambda path: path == ('b',) or path == ('a',), - {'b': 1, 'a': True}, - ), - ( - 'named_tuple_filter_none', - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - lambda _: True, - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - ), - ) - async def test_release_filters_and_delegates_value( - self, value, filter_fn, expected_value - ): - mock_release_mngr = mock.AsyncMock(spec=release_manager.ReleaseManager) - release_mngr = release_manager.FilteringReleaseManager( - mock_release_mngr, filter_fn - ) - key = 1 - - await release_mngr.release(value, key=key) - - mock_release_mngr.release.assert_called_once() - call = mock_release_mngr.release.mock_calls[0] - _, args, kwargs = call - (actual_value,) = args - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - self.assertEqual(kwargs, {'key': key}) - - @parameterized.named_parameters( - # materialized values - ('none', None), - ('bool', True), - ('int', 1), - ('str', 'abc'), - ('numpy_int', np.int32(1)), - ('numpy_array', np.array([1] * 3, np.int32)), - # materializable value references - ( - 'value_reference_tensor', - program_test_utils.TestMaterializableValueReference(1), - ), - ( - 'value_reference_sequence', - program_test_utils.TestMaterializableValueReference([1, 2, 3]), - ), - # serializable values - ('serializable_value', program_test_utils.TestSerializable(1, 2)), - # other values - ('attrs', program_test_utils.TestAttrs(1, 2)), - # structures - ( - 'list', - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - ), - ('list_empty', []), - ( - 'list_nested', - [ - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - ), - ( - 'dict_ordered', - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ( - 'dict_unordered', - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ('dict_empty', {}), - ( - 'dict_nested', - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - ), - ( - 'named_tuple', - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - ), - ( - 'named_tuple_nested', - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - ), - ) - async def test_release_filters_and_does_not_delegate_value(self, value): - mock_release_mngr = mock.AsyncMock(spec=release_manager.ReleaseManager) - filter_fn = lambda _: False - release_mngr = release_manager.FilteringReleaseManager( - mock_release_mngr, filter_fn - ) - key = 1 - - await release_mngr.release(value, key=key) - - mock_release_mngr.release.assert_not_called() - - @parameterized.named_parameters( - ( - 'list_filter_none', - [True, 1, 'abc', [], [2]], - lambda _: True, - [True, 1, 'abc', [2]], - ), - ( - 'list_filter_some', - [True, 1, 'abc', [], [2]], - lambda path: path != (4, 0), - [True, 1, 'abc'], - ), - ( - 'dict_filter_none', - {'a': True, 'b': 1, 'c': 'abc', 'd': {}, 'e': {'a': 2}}, - lambda _: True, - {'a': True, 'b': 1, 'c': 'abc', 'e': {'a': 2}}, - ), - ( - 'dict_filter_some', - {'a': True, 'b': 1, 'c': 'abc', 'd': {}, 'e': {'a': 2}}, - lambda path: path != ('e', 'a'), - {'a': True, 'b': 1, 'c': 'abc'}, - ), - ) - async def test_release_filters_and_does_not_delegate_empty_structures( - self, value, filter_fn, expected_value - ): - mock_release_mngr = mock.AsyncMock(spec=release_manager.ReleaseManager) - release_mngr = release_manager.FilteringReleaseManager( - mock_release_mngr, filter_fn - ) - key = 1 - - await release_mngr.release(value, key=key) - - mock_release_mngr.release.assert_called_once() - call = mock_release_mngr.release.mock_calls[0] - _, args, kwargs = call - (actual_value,) = args - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - self.assertEqual(kwargs, {'key': key}) - - @parameterized.named_parameters( - # structures - ( - 'named_tuple_filter_some', - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - lambda path: path == ('b',) or path == ('c',), - ), - ) - async def test_release_raises_not_filterable_error(self, value, filter_fn): - mock_release_mngr = mock.AsyncMock(spec=release_manager.ReleaseManager) - release_mngr = release_manager.FilteringReleaseManager( - mock_release_mngr, filter_fn - ) - key = 1 - - with self.assertRaises(release_manager.NotFilterableError): - await release_mngr.release(value, key=key) - - -class GroupingReleaseManagerTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - def test_init_does_not_raise_type_error(self): - release_mngrs = [ - mock.AsyncMock(spec=release_manager.ReleaseManager), - mock.AsyncMock(spec=release_manager.ReleaseManager), - mock.AsyncMock(spec=release_manager.ReleaseManager), - ] - - try: - release_manager.GroupingReleaseManager(release_mngrs) - except TypeError: - self.fail('Raised `TypeError` unexpectedly.') - - def test_init_raises_value_error_with_release_manager_empty(self): - release_mngrs = [] - - with self.assertRaises(ValueError): - release_manager.GroupingReleaseManager(release_mngrs) - - async def test_release_delegates_value(self): - release_mngrs = [ - mock.AsyncMock(spec=release_manager.ReleaseManager), - mock.AsyncMock(spec=release_manager.ReleaseManager), - mock.AsyncMock(spec=release_manager.ReleaseManager), - ] - release_mngr = release_manager.GroupingReleaseManager(release_mngrs) - value = 1 - key = 1 - - await release_mngr.release(value, key=key) - - for mock_release_mngr in release_mngrs: - mock_release_mngr.release.assert_called_once_with(value, key=key) - - -class PeriodicReleaseManagerTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - ('int_negative', -1), - ('int_zero', 0), - ('timedelta_negative', datetime.timedelta(seconds=-1)), - ('timedelta_zero', datetime.timedelta()), - ) - def test_init_raises_value_error_with_period(self, periodicity): - mock_release_mngr = mock.AsyncMock( - spec=release_manager.ReleaseManager, set_spec=True - ) - - with self.assertRaises(ValueError): - release_manager.PeriodicReleaseManager(mock_release_mngr, periodicity) - - @parameterized.named_parameters( - ('all_releases', 1, 10, 10), - ('some_releases', 2, 10, 5), - ('last_release', 10, 10, 1), - ('drops_trailing_releases', 3, 10, 3), - ('drops_all_releases', 11, 10, 0), - ) - async def test_release_delegates_value_with_periodicity_int( - self, periodicity, total, expected_count - ): - mock_release_mngr = mock.AsyncMock( - spec=release_manager.ReleaseManager, set_spec=True - ) - release_mngr = release_manager.PeriodicReleaseManager( - mock_release_mngr, periodicity - ) - value = 1 - key = 1 - - for _ in range(total): - await release_mngr.release(value, key=key) - - self.assertEqual(mock_release_mngr.release.call_count, expected_count) - mock_release_mngr.release.assert_has_calls( - [mock.call(value, key=key)] * expected_count - ) - - @parameterized.named_parameters( - ( - 'all_releases', - datetime.timedelta(seconds=1), - [datetime.timedelta(seconds=i) for i in range(1, 11)], - 10, - ), - ( - 'some_releases', - datetime.timedelta(seconds=2), - [datetime.timedelta(seconds=i) for i in range(1, 11)], - 5, - ), - ( - 'last_release', - datetime.timedelta(seconds=10), - [datetime.timedelta(seconds=i) for i in range(1, 11)], - 1, - ), - ( - 'drops_trailing_releases', - datetime.timedelta(seconds=3), - [datetime.timedelta(seconds=i) for i in range(1, 11)], - 3, - ), - ( - 'drops_all_releases', - datetime.timedelta(seconds=11), - [datetime.timedelta(seconds=i) for i in range(1, 11)], - 0, - ), - ) - async def test_release_delegates_value_with_periodicity_timedelta( - self, periodicity, timedeltas, expected_count - ): - mock_release_mngr = mock.AsyncMock( - spec=release_manager.ReleaseManager, set_spec=True - ) - release_mngr = release_manager.PeriodicReleaseManager( - mock_release_mngr, periodicity - ) - value = 1 - key = 1 - - now = datetime.datetime.now() - with mock.patch.object(datetime, 'datetime') as mock_datetime: - mock_datetime.now.side_effect = [now + x for x in timedeltas] - - for _ in timedeltas: - await release_mngr.release(value, key=key) - - self.assertEqual(mock_release_mngr.release.call_count, expected_count) - mock_release_mngr.release.assert_has_calls( - [mock.call(value, key=key)] * expected_count - ) - - -class DelayedReleaseManagerTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - ('int_negative', -1), - ('int_zero', 0), - ) - def test_init_raises_value_error_with_bad_delay(self, delay): - mock_release_mngr = mock.AsyncMock( - spec=release_manager.ReleaseManager, set_spec=True - ) - - with self.assertRaises(ValueError): - release_manager.DelayedReleaseManager(mock_release_mngr, delay) - - @parameterized.named_parameters( - ('all_releases', 1, 10, 10), - ('some_releases', 3, 10, 8), - ('last_release', 10, 10, 1), - ('drops_all_releases', 11, 10, 0), - ) - async def test_release_delegates_value_with_delay( - self, delay, total, expected_count - ): - mock_release_mngr = mock.AsyncMock( - spec=release_manager.ReleaseManager, set_spec=True - ) - release_mngr = release_manager.DelayedReleaseManager( - mock_release_mngr, delay - ) - value = 1 - key = 1 - - for _ in range(total): - await release_mngr.release(value, key=key) - - self.assertEqual(mock_release_mngr.release.call_count, expected_count) - mock_release_mngr.release.assert_has_calls( - [mock.call(value, key=key)] * expected_count - ) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/program/serialization_utils.py b/tensorflow_federated/python/program/serialization_utils.py index c21abee5b8..413923d99d 100644 --- a/tensorflow_federated/python/program/serialization_utils.py +++ b/tensorflow_federated/python/program/serialization_utils.py @@ -31,12 +31,10 @@ import struct from typing import Protocol, TypeVar +import federated_language +from federated_language.proto import computation_pb2 import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 -from tensorflow_federated.python.common_libs import serializable -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import type_serialization from tensorflow_federated.python.program import structure_utils @@ -139,7 +137,7 @@ def unpack_sequence_from( return sequence, length_size + length -def pack_serializable(value: serializable.Serializable) -> bytes: +def pack_serializable(value: federated_language.Serializable) -> bytes: """Packs a `tff.Serializable` as bytes.""" module_name_bytes = pack_str(type(value).__module__) class_name_bytes = pack_str(type(value).__name__) @@ -155,7 +153,7 @@ def pack_serializable(value: serializable.Serializable) -> bytes: def unpack_serializable_from( buffer: bytes, offset: int = 0 -) -> tuple[serializable.Serializable, int]: +) -> tuple[federated_language.Serializable, int]: """Unpacks a `tff.Serializable` from bytes. Args: @@ -188,9 +186,9 @@ def unpack_serializable_from( ) -def pack_type_spec(type_spec: computation_types.Type) -> bytes: +def pack_type_spec(type_spec: federated_language.Type) -> bytes: """Packs a `tff.Type` as bytes.""" - proto = type_serialization.serialize_type(type_spec) + proto = federated_language.framework.serialize_type(type_spec) type_bytes = proto.SerializeToString() length_bytes = _pack_length(type_bytes) return length_bytes + type_bytes @@ -198,7 +196,7 @@ def pack_type_spec(type_spec: computation_types.Type) -> bytes: def unpack_type_spec_from( buffer: bytes, offset: int = 0 -) -> tuple[computation_types.Type, int]: +) -> tuple[federated_language.Type, int]: """Unpacks a `tff.Type` from bytes. Args: @@ -212,7 +210,7 @@ def unpack_type_spec_from( offset += length_size type_spec_bytes, *_ = struct.unpack_from(f'!{length}s', buffer, offset=offset) proto = computation_pb2.Type.FromString(type_spec_bytes) - type_spec = type_serialization.deserialize_type(proto) + type_spec = federated_language.framework.deserialize_type(proto) return type_spec, length_size + length diff --git a/tensorflow_federated/python/program/serialization_utils_test.py b/tensorflow_federated/python/program/serialization_utils_test.py index 97f1aefbf3..b5ef9ce8df 100644 --- a/tensorflow_federated/python/program/serialization_utils_test.py +++ b/tensorflow_federated/python/program/serialization_utils_test.py @@ -17,10 +17,10 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.program import program_test_utils from tensorflow_federated.python.program import serialization_utils @@ -214,7 +214,7 @@ def test_unpack_serializable_from_raises_struct_error_with_corrupt_bytes( class SerializationUtilsTypeSpecTest(parameterized.TestCase): def test_pack_and_unpack_type_spec(self): - value = computation_types.TensorType(np.int32) + value = federated_language.TensorType(np.int32) value_bytes = serialization_utils.pack_type_spec(value) actual_value, actual_size = serialization_utils.unpack_type_spec_from( @@ -226,7 +226,7 @@ def test_pack_and_unpack_type_spec(self): self.assertEqual(actual_size, expected_size) def test_pack_and_unpack_type_spec_with_offset(self): - value = computation_types.TensorType(np.int32) + value = federated_language.TensorType(np.int32) offset = 100 value_bytes = serialization_utils.pack_type_spec(value) @@ -244,7 +244,7 @@ def test_pack_and_unpack_type_spec_with_offset(self): ('too_large', 1), ) def test_unpack_type_spec_from_raises_struct_error_with_offset(self, offset): - value = computation_types.TensorType(np.int32) + value = federated_language.TensorType(np.int32) value_bytes = serialization_utils.pack_type_spec(value) with self.assertRaises(struct.error): @@ -257,7 +257,7 @@ def test_unpack_type_spec_from_raises_struct_error_with_offset(self, offset): def test_unpack_type_spec_from_raises_struct_error_with_corrupt_bytes( self, corrupt_fn ): - value = computation_types.TensorType(np.int32) + value = federated_language.TensorType(np.int32) value_bytes = serialization_utils.pack_type_spec(value) corrupt_bytes = corrupt_fn(value_bytes) diff --git a/tensorflow_federated/python/program/tensorboard_release_manager.py b/tensorflow_federated/python/program/tensorboard_release_manager.py index f125eccef5..4694d8fa23 100644 --- a/tensorflow_federated/python/program/tensorboard_release_manager.py +++ b/tensorflow_federated/python/program/tensorboard_release_manager.py @@ -16,16 +16,17 @@ import os from typing import Union +import federated_language import numpy as np import tensorflow as tf -from tensorflow_federated.python.program import release_manager from tensorflow_federated.python.program import structure_utils -from tensorflow_federated.python.program import value_reference class TensorBoardReleaseManager( - release_manager.ReleaseManager[release_manager.ReleasableStructure, int] + federated_language.program.ReleaseManager[ + federated_language.program.ReleasableStructure, int + ] ): """A `tff.program.ReleaseManager` that releases values to TensorBoard. @@ -67,7 +68,7 @@ def __init__(self, summary_dir: Union[str, os.PathLike[str]]): self._summary_writer = tf.summary.create_file_writer(summary_dir) async def release( - self, value: release_manager.ReleasableStructure, key: int + self, value: federated_language.program.ReleasableStructure, key: int ) -> None: """Releases `value` from a federated program. @@ -76,12 +77,14 @@ async def release( key: A integer used to reference the released `value`; `key` represents a step in a federated program. """ - materialized_value = await value_reference.materialize_value(value) + materialized_value = await federated_language.program.materialize_value( + value + ) flattened_value = structure_utils.flatten_with_name(materialized_value) def _normalize( - value: value_reference.MaterializedValue, - ) -> value_reference.MaterializedValue: + value: federated_language.program.MaterializedValue, + ) -> federated_language.program.MaterializedValue: if isinstance(value, tf.data.Dataset): value = list(value) return value diff --git a/tensorflow_federated/python/program/value_reference.py b/tensorflow_federated/python/program/value_reference.py deleted file mode 100644 index d6fcbbf9ec..0000000000 --- a/tensorflow_federated/python/program/value_reference.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines abstract interfaces for representing references to values. - -These abstract interfaces provide the capability to handle values without -requiring them to be materialized as Python objects. Instances of these -abstract interfaces represent values of type `tff.TensorType` and can be placed -on the server, elements of structures that are placed on the server, or -unplaced. -""" - -import abc -import asyncio -from collections.abc import Iterable -from typing import Union - -import numpy as np - -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import typed_object -from tensorflow_federated.python.program import structure_utils - -MaterializableTypeSignature = Union[ - computation_types.TensorType, - computation_types.SequenceType, -] -_MaterializedArrayValue = Union[ - # Python types - bool, - int, - float, - complex, - str, - bytes, - # Numpy types - np.generic, - np.ndarray, -] -MaterializedValue = Union[ - _MaterializedArrayValue, - Iterable[_MaterializedArrayValue], -] -MaterializedStructure = structure_utils.Structure[MaterializedValue] -MaterializableValue = Union[ - MaterializedValue, - 'MaterializableValueReference', -] -MaterializableStructure = structure_utils.Structure[MaterializableValue] - - -class MaterializableValueReference(abc.ABC, typed_object.TypedObject): - """An abstract interface representing references to server-placed values.""" - - @property - @abc.abstractmethod - def type_signature(self) -> MaterializableTypeSignature: - """The `tff.Type` of this object.""" - raise NotImplementedError - - @abc.abstractmethod - async def get_value(self) -> MaterializedValue: - """Returns the referenced value. - - The Python type of the referenced value depends on the `type_signature`: - - | TFF Type | Python Type | - | ------------------ | -------------------------------------------------- | - | `tff.TensorType` | `bool`, `int`, `float`, `complex`, `str`, `bytes`, | - | | `np.generic`, or `np.ndarray` | - | `tff.SequenceType` | `Iterable` of any Python type corresponding to a | - | | `tff.TensorType` | - """ - raise NotImplementedError - - -async def materialize_value( - value: MaterializableStructure, -) -> MaterializedStructure: - """Materializes the `tff.program.MaterializableValueReference`s in `value`. - - Args: - value: A `tff.program.MaterializableStructure` to materialize. - - Returns: - A `tff.program.MaterializedStructure`. - """ - - async def _materialize(value: MaterializableValue) -> MaterializedValue: - if isinstance(value, MaterializableValueReference): - return await value.get_value() - else: - return value - - flattened_value = structure_utils.flatten(value) - materialized_value = await asyncio.gather( - *[_materialize(v) for v in flattened_value] - ) - return structure_utils.unflatten_as(value, materialized_value) diff --git a/tensorflow_federated/python/program/value_reference_test.py b/tensorflow_federated/python/program/value_reference_test.py deleted file mode 100644 index 421997ebf4..0000000000 --- a/tensorflow_federated/python/program/value_reference_test.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from absl.testing import absltest -from absl.testing import parameterized -import numpy as np -import tree - -from tensorflow_federated.python.program import program_test_utils -from tensorflow_federated.python.program import value_reference - - -class MaterializeValueTest( - parameterized.TestCase, unittest.IsolatedAsyncioTestCase -): - - @parameterized.named_parameters( - # materialized values - ('bool', True, True), - ('int', 1, 1), - ('str', 'abc', 'abc'), - ('numpy_generic', np.int32(1), np.int32(1)), - ('numpy_array', np.array([1] * 3, np.int32), np.array([1] * 3, np.int32)), - # materializable value references - ( - 'materializable_value_reference_tensor', - program_test_utils.TestMaterializableValueReference(1), - 1, - ), - ( - 'materializable_value_reference_sequence', - program_test_utils.TestMaterializableValueReference([1, 2, 3]), - [1, 2, 3], - ), - # serializable values - ( - 'serializable_value', - program_test_utils.TestSerializable(1, 2), - program_test_utils.TestSerializable(1, 2), - ), - # other values - ( - 'attrs', - program_test_utils.TestAttrs(1, 2), - program_test_utils.TestAttrs(1, 2), - ), - # structures - ( - 'list', - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [ - True, - 1, - 'abc', - 2, - program_test_utils.TestSerializable(3, 4), - ], - ), - ('list_empty', [], []), - ( - 'list_nested', - [ - [ - True, - 1, - 'abc', - program_test_utils.TestMaterializableValueReference(2), - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - [ - [ - True, - 1, - 'abc', - 2, - program_test_utils.TestSerializable(3, 4), - ], - [5], - ], - ), - ( - 'dict_ordered', - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ( - 'dict_unordered', - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - { - 'c': True, - 'b': 1, - 'a': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - ), - ('dict_empty', {}, {}), - ( - 'dict_nested', - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': program_test_utils.TestMaterializableValueReference(2), - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - { - 'x': { - 'a': True, - 'b': 1, - 'c': 'abc', - 'd': 2, - 'e': program_test_utils.TestSerializable(3, 4), - }, - 'y': {'a': 5}, - }, - ), - ( - 'named_tuple', - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=2, - e=program_test_utils.TestSerializable(3, 4), - ), - ), - ( - 'named_tuple_nested', - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=program_test_utils.TestMaterializableValueReference(2), - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - program_test_utils.TestNamedTuple3( - x=program_test_utils.TestNamedTuple1( - a=True, - b=1, - c='abc', - d=2, - e=program_test_utils.TestSerializable(3, 4), - ), - y=program_test_utils.TestNamedTuple2(a=5), - ), - ), - ) - async def test_returns_value(self, value, expected_value): - actual_value = await value_reference.materialize_value(value) - - tree.assert_same_structure(actual_value, expected_value) - program_test_utils.assert_same_key_order(actual_value, expected_value) - actual_value = program_test_utils.to_python(actual_value) - expected_value = program_test_utils.to_python(expected_value) - self.assertEqual(actual_value, expected_value) - - -if __name__ == '__main__': - absltest.main() diff --git a/tensorflow_federated/python/simulation/BUILD b/tensorflow_federated/python/simulation/BUILD index b4a9f0649d..49dd9bc225 100644 --- a/tensorflow_federated/python/simulation/BUILD +++ b/tensorflow_federated/python/simulation/BUILD @@ -40,14 +40,9 @@ py_library( deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", - "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -58,12 +53,9 @@ py_test( deps = [ ":iterative_process_compositions", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/learning/templates:learning_process", + "@federated_language//federated_language", ], ) @@ -84,10 +76,8 @@ py_library( srcs = ["training_loop.py"], deps = [ "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/core/templates:iterative_process", - "//tensorflow_federated/python/program:program_state_manager", - "//tensorflow_federated/python/program:release_manager", + "@federated_language//federated_language", ], ) @@ -97,11 +87,7 @@ py_test( srcs = ["training_loop_test.py"], deps = [ ":training_loop", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/federated_context:federated_computation", - "//tensorflow_federated/python/core/impl/federated_context:intrinsics", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/simulation/baselines/BUILD b/tensorflow_federated/python/simulation/baselines/BUILD index 7da0f77293..52dc1ecddc 100644 --- a/tensorflow_federated/python/simulation/baselines/BUILD +++ b/tensorflow_federated/python/simulation/baselines/BUILD @@ -71,8 +71,8 @@ py_library( name = "task_data", srcs = ["task_data.py"], deps = [ - "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/simulation/datasets:client_data", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/simulation/baselines/task_data.py b/tensorflow_federated/python/simulation/baselines/task_data.py index 4881482878..afd3433de3 100644 --- a/tensorflow_federated/python/simulation/baselines/task_data.py +++ b/tensorflow_federated/python/simulation/baselines/task_data.py @@ -17,15 +17,16 @@ from collections.abc import Callable from typing import Any, Optional, Union +import federated_language import numpy as np import tensorflow as tf -from tensorflow_federated.python.core.impl.computation import computation_base from tensorflow_federated.python.simulation.datasets import client_data CentralOrClientData = Union[tf.data.Dataset, client_data.ClientData] PreprocessFnType = Union[ - Callable[[tf.data.Dataset], tf.data.Dataset], computation_base.Computation + Callable[[tf.data.Dataset], tf.data.Dataset], + federated_language.framework.Computation, ] diff --git a/tensorflow_federated/python/simulation/datasets/BUILD b/tensorflow_federated/python/simulation/datasets/BUILD index 92fdd025c3..772306a0f9 100644 --- a/tensorflow_federated/python/simulation/datasets/BUILD +++ b/tensorflow_federated/python/simulation/datasets/BUILD @@ -227,7 +227,7 @@ py_library( deps = [ "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation", - "//tensorflow_federated/python/core/impl/computation:computation_base", + "@federated_language//federated_language", ], ) @@ -259,7 +259,7 @@ py_test( deps = [ ":file_per_user_client_data", "//tensorflow_federated/python/core/backends/native:execution_contexts", - "//tensorflow_federated/python/core/impl/computation:computation_base", + "@federated_language//federated_language", ], ) @@ -280,8 +280,7 @@ py_test( ":from_tensor_slices_client_data", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_types", - "//tensorflow_federated/python/core/impl/computation:computation_base", - "//tensorflow_federated/python/core/impl/types:computation_types", + "@federated_language//federated_language", ], ) diff --git a/tensorflow_federated/python/simulation/datasets/client_data.py b/tensorflow_federated/python/simulation/datasets/client_data.py index 87a18d6d42..da92282aa6 100644 --- a/tensorflow_federated/python/simulation/datasets/client_data.py +++ b/tensorflow_federated/python/simulation/datasets/client_data.py @@ -18,12 +18,12 @@ from typing import Any, Optional, Union from absl import logging +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.computation import computation_base class IncompatiblePreprocessFnError(TypeError): @@ -134,7 +134,7 @@ def create_tf_dataset_for_client(self, client_id: str) -> tf.data.Dataset: This function will create a dataset for a given client, given that `client_id` is contained in the `client_ids` property of the `ClientData`. - Unlike `create_dataset`, this method need not be serializable. + Unlike `create_dataset`, this method need not be federated_language. Args: client_id: The string client_id for the desired client. @@ -253,7 +253,7 @@ def preprocess( Raises: IncompatiblePreprocessFnError: If `preprocess_fn` is a `tff.Computation`. """ - if isinstance(preprocess_fn, computation_base.Computation): + if isinstance(preprocess_fn, federated_language.framework.Computation): raise IncompatiblePreprocessFnError() return PreprocessClientData(self, preprocess_fn) @@ -282,7 +282,9 @@ def from_clients_and_tf_fn( Returns: A `ClientData` object. """ - if isinstance(serializable_dataset_fn, computation_base.Computation): + if isinstance( + serializable_dataset_fn, federated_language.framework.Computation + ): raise TypeError( 'The input serializable_dataset_fn cannot be a tff.Computation, as it' ' must be serializable within the context of a tf.function.' diff --git a/tensorflow_federated/python/simulation/datasets/file_per_user_client_data.py b/tensorflow_federated/python/simulation/datasets/file_per_user_client_data.py index f8af860eb8..35e78c216d 100644 --- a/tensorflow_federated/python/simulation/datasets/file_per_user_client_data.py +++ b/tensorflow_federated/python/simulation/datasets/file_per_user_client_data.py @@ -83,7 +83,8 @@ def create_tf_dataset_for_client(self, client_id: str) -> tf.data.Dataset: This function will create a dataset for a given client if `client_id` is contained in the `client_ids` property of the `FilePerUserClientData`. - Unlike `self.serializable_dataset_fn`, this method is not serializable. + Unlike `self.serializable_dataset_fn`, this method is not + federated_language. Args: client_id: The string identifier for the desired client. diff --git a/tensorflow_federated/python/simulation/datasets/file_per_user_client_data_test.py b/tensorflow_federated/python/simulation/datasets/file_per_user_client_data_test.py index ee91506d31..452e6010c0 100644 --- a/tensorflow_federated/python/simulation/datasets/file_per_user_client_data_test.py +++ b/tensorflow_federated/python/simulation/datasets/file_per_user_client_data_test.py @@ -24,10 +24,10 @@ import os.path import tempfile +import federated_language import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts -from tensorflow_federated.python.core.impl.computation import computation_base from tensorflow_federated.python.simulation.datasets import file_per_user_client_data # A fake columnar dataset of (user id, value 1, value 2, value 3), roughly @@ -211,7 +211,7 @@ def test_create_tf_dataset_from_all_clients(self): def test_dataset_computation(self): data = self._create_fake_client_data() self.assertIsInstance( - data.dataset_computation, computation_base.Computation + data.dataset_computation, federated_language.framework.Computation ) # Iterate over each client, ensuring we received a tf.data.Dataset with the # correct data. diff --git a/tensorflow_federated/python/simulation/datasets/from_tensor_slices_client_data_test.py b/tensorflow_federated/python/simulation/datasets/from_tensor_slices_client_data_test.py index 9b260b64b2..aea1cbd5c6 100644 --- a/tensorflow_federated/python/simulation/datasets/from_tensor_slices_client_data_test.py +++ b/tensorflow_federated/python/simulation/datasets/from_tensor_slices_client_data_test.py @@ -16,13 +16,12 @@ import copy from absl.testing import parameterized +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_types -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.simulation.datasets import from_tensor_slices_client_data TEST_DATA = { @@ -261,11 +260,13 @@ def get_flat_dataset(seed): def test_dataset_computation_where_client_data_is_tensors(self): client_data = from_tensor_slices_client_data.TestClientData(TEST_DATA) dataset_computation = client_data.dataset_computation - self.assertIsInstance(dataset_computation, computation_base.Computation) + self.assertIsInstance( + dataset_computation, federated_language.framework.Computation + ) - expected_dataset_comp_type_signature = computation_types.FunctionType( - computation_types.TensorType(np.str_), - computation_types.SequenceType( + expected_dataset_comp_type_signature = federated_language.FunctionType( + federated_language.TensorType(np.str_), + federated_language.SequenceType( tensorflow_types.to_type( (client_data.element_type_structure.dtype, (2,)) ) @@ -291,12 +292,14 @@ def test_dataset_computation_where_client_data_is_tuples(self): TEST_DATA_WITH_TUPLES ) dataset_computation = client_data.dataset_computation - self.assertIsInstance(dataset_computation, computation_base.Computation) - expected_dataset_comp_type_signature = computation_types.FunctionType( - computation_types.TensorType(np.str_), - computation_types.SequenceType(( - computation_types.TensorType(np.int32), - computation_types.TensorType(np.int32), + self.assertIsInstance( + dataset_computation, federated_language.framework.Computation + ) + expected_dataset_comp_type_signature = federated_language.FunctionType( + federated_language.TensorType(np.str_), + federated_language.SequenceType(( + federated_language.TensorType(np.int32), + federated_language.TensorType(np.int32), )), ) @@ -320,10 +323,12 @@ def test_dataset_computation_where_client_data_is_ordered_dicts(self): TEST_DATA_WITH_ORDEREDDICTS ) dataset_computation = client_data.dataset_computation - self.assertIsInstance(dataset_computation, computation_base.Computation) - expected_dataset_comp_type_signature = computation_types.FunctionType( - computation_types.TensorType(np.str_), - computation_types.SequenceType( + self.assertIsInstance( + dataset_computation, federated_language.framework.Computation + ) + expected_dataset_comp_type_signature = federated_language.FunctionType( + federated_language.TensorType(np.str_), + federated_language.SequenceType( collections.OrderedDict([ ( 'x', diff --git a/tensorflow_federated/python/simulation/datasets/sql_client_data.py b/tensorflow_federated/python/simulation/datasets/sql_client_data.py index 78d7916a63..16e83ff8b8 100644 --- a/tensorflow_federated/python/simulation/datasets/sql_client_data.py +++ b/tensorflow_federated/python/simulation/datasets/sql_client_data.py @@ -166,7 +166,7 @@ def create_tf_dataset_for_client(self, client_id: str): This function will create a dataset for a given client if `client_id` is contained in the `client_ids` property of the `SQLClientData`. Unlike - `self.serializable_dataset_fn`, this method is not serializable. + `self.serializable_dataset_fn`, this method is not federated_language. Args: client_id: The string identifier for the desired client. diff --git a/tensorflow_federated/python/simulation/iterative_process_compositions.py b/tensorflow_federated/python/simulation/iterative_process_compositions.py index 4df2fcb654..aca0f20259 100644 --- a/tensorflow_federated/python/simulation/iterative_process_compositions.py +++ b/tensorflow_federated/python/simulation/iterative_process_compositions.py @@ -13,14 +13,9 @@ # limitations under the License. """Library of compositional helpers for iterative processes.""" +import federated_language from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.learning.templates import learning_process @@ -38,9 +33,9 @@ class MultipleMatchingSequenceTypesError(TypeError): def compose_dataset_computation_with_computation( - dataset_computation: computation_base.Computation, - computation_body: computation_base.Computation, -) -> computation_base.Computation: + dataset_computation: federated_language.framework.Computation, + computation_body: federated_language.framework.Computation, +) -> federated_language.framework.Computation: """Builds a new `tff.Computation` which constructs datasets on clients. Given a `tff.Computation` that returns a `tf.data.Dataset`, and a @@ -96,11 +91,15 @@ def compose_dataset_computation_with_computation( `computation_body` declares more than one sequence parameter matching the expected dataset type. """ - py_typecheck.check_type(dataset_computation, computation_base.Computation) - py_typecheck.check_type(computation_body, computation_base.Computation) + py_typecheck.check_type( + dataset_computation, federated_language.framework.Computation + ) + py_typecheck.check_type( + computation_body, federated_language.framework.Computation + ) dataset_return_type = dataset_computation.type_signature.result - if not isinstance(dataset_return_type, computation_types.SequenceType): + if not isinstance(dataset_return_type, federated_language.SequenceType): raise TypeError( 'Expected a `tff.SequenceType` to be returned from ' '`dataset_computation`; found {} instead.'.format(dataset_return_type) @@ -118,23 +117,25 @@ def compose_dataset_computation_with_computation( comp_body_param_type = computation_body.type_signature.parameter def is_desired_federated_sequence(t): - if not isinstance(t, computation_types.FederatedType): + if not isinstance(t, federated_language.FederatedType): return False return t.member.is_assignable_from(dataset_return_type) if is_desired_federated_sequence(comp_body_param_type): # Single argument that matches, we compose in a straightforward manner. - new_param_type = computation_types.FederatedType( - dataset_computation.type_signature.parameter, placements.CLIENTS + new_param_type = federated_language.FederatedType( + dataset_computation.type_signature.parameter, federated_language.CLIENTS ) - @federated_computation.federated_computation(new_param_type) + @federated_language.federated_computation(new_param_type) def new_computation(param): - datasets_on_clients = intrinsics.federated_map(dataset_computation, param) + datasets_on_clients = federated_language.federated_map( + dataset_computation, param + ) return computation_body(datasets_on_clients) return new_computation - elif isinstance(comp_body_param_type, computation_types.StructType): + elif isinstance(comp_body_param_type, federated_language.StructType): # If the computation has multiple arguments we need to search over them # recursively to find the one that matches the type signature of # dataset_computation's result. @@ -144,15 +145,15 @@ def new_computation(param): dataset_index_path = None # Federated version of the dataset_computation's argument type signature to # use in the final computation type. - federated_param_type = computation_types.FederatedType( - dataset_computation.type_signature.parameter, placements.CLIENTS + federated_param_type = federated_language.FederatedType( + dataset_computation.type_signature.parameter, federated_language.CLIENTS ) # Tracks all sequence types encountered in the recursive search for the # error message in case the desired argument is not found. sequence_types = [] def build_new_param_type( - struct_param_type: computation_types.StructType, index_path + struct_param_type: federated_language.StructType, index_path ): """Builds a new struct parameter type. @@ -178,8 +179,8 @@ def build_new_param_type( structure.iter_elements(struct_param_type) ): if isinstance( - elem_type, computation_types.FederatedType - ) and isinstance(elem_type.member, computation_types.SequenceType): + elem_type, federated_language.FederatedType + ) and isinstance(elem_type.member, federated_language.SequenceType): sequence_types.append(elem_type.member) if is_desired_federated_sequence(elem_type): @@ -193,13 +194,13 @@ def build_new_param_type( ) dataset_index_path = index_path + [idx] new_param_elements.append((elem_name, federated_param_type)) - elif isinstance(elem_type, computation_types.StructType): + elif isinstance(elem_type, federated_language.StructType): new_param_elements.append( (elem_name, build_new_param_type(elem_type, index_path + [idx])) ) else: new_param_elements.append((elem_name, elem_type)) - return computation_types.StructType(new_param_elements) + return federated_language.StructType(new_param_elements) new_param_type = build_new_param_type(comp_body_param_type, []) @@ -245,14 +246,14 @@ def map_at_path(param, index_path, depth, computation): if idx != index_path[depth]: ret_param.append(elem) elif depth == len(index_path) - 1: - ret_param.append(intrinsics.federated_map(computation, elem)) + ret_param.append(federated_language.federated_map(computation, elem)) else: ret_param.append( map_at_path(elem, index_path, depth + 1, computation) ) return ret_param - @federated_computation.federated_computation(new_param_type) + @federated_language.federated_computation(new_param_type) def new_computation(param): return computation_body( map_at_path(param, dataset_index_path, 0, dataset_computation) @@ -270,7 +271,7 @@ def new_computation(param): def compose_dataset_computation_with_iterative_process( - dataset_computation: computation_base.Computation, + dataset_computation: federated_language.framework.Computation, process: iterative_process.IterativeProcess, ) -> iterative_process.IterativeProcess: """Builds a new iterative process which constructs datasets on clients. @@ -327,11 +328,13 @@ def compose_dataset_computation_with_iterative_process( TypeError: If the arguments are of the wrong types, or their TFF type signatures are incompatible with the specification of this function. """ - py_typecheck.check_type(dataset_computation, computation_base.Computation) + py_typecheck.check_type( + dataset_computation, federated_language.framework.Computation + ) py_typecheck.check_type(process, iterative_process.IterativeProcess) dataset_return_type = dataset_computation.type_signature.result - if not isinstance(dataset_return_type, computation_types.SequenceType): + if not isinstance(dataset_return_type, federated_language.SequenceType): raise TypeError( 'Expected a `tff.SequenceType` to be returned from ' '`dataset_computation`; found {} instead.'.format(dataset_return_type) @@ -347,9 +350,9 @@ def compose_dataset_computation_with_iterative_process( ) init_fn = process.initialize - if type_analysis.contains( + if federated_language.framework.type_contains( init_fn.type_signature.result, - lambda x: isinstance(x, computation_types.SequenceType), + lambda x: isinstance(x, federated_language.SequenceType), ): raise TypeError( 'Cannot construct a new iterative process if a dataset is ' @@ -366,7 +369,7 @@ def compose_dataset_computation_with_iterative_process( def compose_dataset_computation_with_learning_process( - dataset_computation: computation_base.Computation, + dataset_computation: federated_language.framework.Computation, process: learning_process.LearningProcess, ) -> learning_process.LearningProcess: """Builds a new learning process which constructs datasets on clients. diff --git a/tensorflow_federated/python/simulation/iterative_process_compositions_test.py b/tensorflow_federated/python/simulation/iterative_process_compositions_test.py index 34ee50b794..7d4c657d11 100644 --- a/tensorflow_federated/python/simulation/iterative_process_compositions_test.py +++ b/tensorflow_federated/python/simulation/iterative_process_compositions_test.py @@ -15,14 +15,11 @@ import collections from absl.testing import absltest +import federated_language import numpy as np import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.learning.templates import learning_process from tensorflow_federated.python.simulation import iterative_process_compositions @@ -45,11 +42,13 @@ def _create_federated_int_dataset_identity_iterative_process(): def create_dataset(): return tf.data.Dataset.range(5) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init(): - return intrinsics.federated_eval(create_dataset, placements.CLIENTS) + return federated_language.federated_eval( + create_dataset, federated_language.CLIENTS + ) - @federated_computation.federated_computation(init.type_signature.result) + @federated_language.federated_computation(init.type_signature.result) def next_fn(x): return x @@ -61,26 +60,28 @@ def _create_stateless_int_dataset_reduction_iterative_process(): def make_zero(): return tf.cast(0, tf.int64) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init(): - return intrinsics.federated_eval(make_zero, placements.SERVER) + return federated_language.federated_eval( + make_zero, federated_language.SERVER + ) @tensorflow_computation.tf_computation( - computation_types.SequenceType(np.int64) + federated_language.SequenceType(np.int64) ) def reduce_dataset(x): return x.reduce(tf.cast(0, tf.int64), lambda x, y: x + y) - @federated_computation.federated_computation(( + @federated_language.federated_computation(( init.type_signature.result, - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.int64), federated_language.CLIENTS ), )) def next_fn(server_state, client_data): del server_state # Unused - return intrinsics.federated_sum( - intrinsics.federated_map(reduce_dataset, client_data) + return federated_language.federated_sum( + federated_language.federated_map(reduce_dataset, client_data) ) return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn) @@ -93,13 +94,15 @@ def _create_stateless_int_vector_unknown_dim_dataset_reduction_iterative_process def make_zero(): return tf.reshape(tf.cast(0, tf.int64), shape=[1]) - @federated_computation.federated_computation() + @federated_language.federated_computation() def init(): - return intrinsics.federated_eval(make_zero, placements.SERVER) + return federated_language.federated_eval( + make_zero, federated_language.SERVER + ) @tensorflow_computation.tf_computation( - computation_types.SequenceType( - computation_types.TensorType(np.int64, shape=[None]) + federated_language.SequenceType( + federated_language.TensorType(np.int64, shape=[None]) ) ) def reduce_dataset(x): @@ -107,22 +110,22 @@ def reduce_dataset(x): tf.cast(tf.constant([0]), tf.int64), lambda x, y: x + tf.reduce_sum(y) ) - @federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.TensorType(np.int64, shape=[None]), - placements.SERVER, + @federated_language.federated_computation( + federated_language.FederatedType( + federated_language.TensorType(np.int64, shape=[None]), + federated_language.SERVER, ), - computation_types.FederatedType( - computation_types.SequenceType( - computation_types.TensorType(np.int64, shape=[None]) + federated_language.FederatedType( + federated_language.SequenceType( + federated_language.TensorType(np.int64, shape=[None]) ), - placements.CLIENTS, + federated_language.CLIENTS, ), ) def next_fn(server_state, client_data): del server_state # Unused - return intrinsics.federated_sum( - intrinsics.federated_map(reduce_dataset, client_data) + return federated_language.federated_sum( + federated_language.federated_map(reduce_dataset, client_data) ) return iterative_process.IterativeProcess(initialize_fn=init, next_fn=next_fn) @@ -151,10 +154,10 @@ def int_identity(x): return x -@federated_computation.federated_computation( +@federated_language.federated_computation( np.int32, - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.int64), federated_language.CLIENTS ), np.float32, ) @@ -162,12 +165,13 @@ def test_int64_sequence_struct_computation(a, dataset, b): return a, dataset, b -@federated_computation.federated_computation( +@federated_language.federated_computation( np.int32, - computation_types.StructType([ + federated_language.StructType([ np.int64, - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.int64), + federated_language.CLIENTS, ), np.float64, ]), @@ -177,28 +181,29 @@ def test_int64_sequence_nested_struct_computation(a, dataset, b): return a, dataset, b -@federated_computation.federated_computation( - computation_types.StructType([ - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS +@federated_language.federated_computation( + federated_language.StructType([ + federated_language.FederatedType( + federated_language.SequenceType(np.int64), + federated_language.CLIENTS, ), ]), - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.int64), federated_language.CLIENTS ), ) def test_int64_sequence_multiple_matching_federated_types_computation(a, b): return a, b -@federated_computation.federated_computation( - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS +@federated_language.federated_computation( + federated_language.FederatedType( + federated_language.SequenceType(np.int64), federated_language.CLIENTS ) ) def test_int64_sequence_computation(dataset): del dataset - return intrinsics.federated_value(5, placements.SERVER) + return federated_language.federated_value(5, federated_language.SERVER) class ConstructDatasetsOnClientsComputationTest(absltest.TestCase): @@ -224,7 +229,7 @@ def test_raises_computation_not_returning_dataset(self): ) def test_raises_computation_no_dataset_parameter(self): - no_dataset_comp = federated_computation.federated_computation( + no_dataset_comp = federated_language.federated_computation( lambda x: x, [np.int32] ) with self.assertRaises( @@ -243,9 +248,13 @@ def test_raises_mismatched_dataset_comp_return_type_and_sequence_type(self): ) def test_mutates_comp_accepting_only_dataset(self): - expected_new_next_type_signature = computation_types.FunctionType( - parameter=computation_types.FederatedType(np.str_, placements.CLIENTS), - result=computation_types.FederatedType(np.int32, placements.SERVER), + expected_new_next_type_signature = federated_language.FunctionType( + parameter=federated_language.FederatedType( + np.str_, federated_language.CLIENTS + ), + result=federated_language.FederatedType( + np.int32, federated_language.SERVER + ), ) new_comp = iterative_process_compositions.compose_dataset_computation_with_computation( int_dataset_computation, test_int64_sequence_computation @@ -255,18 +264,19 @@ def test_mutates_comp_accepting_only_dataset(self): ) def test_mutates_comp_accepting_dataset_in_second_index(self): - expected_new_next_type_signature = computation_types.FunctionType( + expected_new_next_type_signature = federated_language.FunctionType( parameter=collections.OrderedDict( a=np.int32, - dataset=computation_types.FederatedType( - np.str_, placements.CLIENTS + dataset=federated_language.FederatedType( + np.str_, federated_language.CLIENTS ), b=np.float32, ), result=( np.int32, - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.int64), + federated_language.CLIENTS, ), np.float32, ), @@ -288,12 +298,14 @@ def test_raises_computation_with_multiple_federated_types(self): ) def test_mutates_comp_accepting_deeply_nested_dataset(self): - expected_new_next_type_signature = computation_types.FunctionType( + expected_new_next_type_signature = federated_language.FunctionType( parameter=collections.OrderedDict( a=np.int32, - dataset=computation_types.StructType([ + dataset=federated_language.StructType([ np.int64, - computation_types.FederatedType(np.str_, placements.CLIENTS), + federated_language.FederatedType( + np.str_, federated_language.CLIENTS + ), np.float64, ]), b=np.float32, @@ -363,16 +375,16 @@ def test_raises_iterproc_if_dataset_is_returned_by_init(self): def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self): iterproc = _create_stateless_int_dataset_reduction_iterative_process() - expected_new_next_type_signature = computation_types.FunctionType( + expected_new_next_type_signature = federated_language.FunctionType( collections.OrderedDict( - server_state=computation_types.FederatedType( - np.int64, placements.SERVER + server_state=federated_language.FederatedType( + np.int64, federated_language.SERVER ), - client_data=computation_types.FederatedType( - np.str_, placements.CLIENTS + client_data=federated_language.FederatedType( + np.str_, federated_language.CLIENTS ), ), - computation_types.FederatedType(np.int64, placements.SERVER), + federated_language.FederatedType(np.int64, federated_language.SERVER), ) new_iterproc = iterative_process_compositions.compose_dataset_computation_with_iterative_process( @@ -387,18 +399,19 @@ def test_mutates_iterproc_with_parameter_assignable_from_result(self): iterproc = ( _create_stateless_int_vector_unknown_dim_dataset_reduction_iterative_process() ) - expected_new_next_type_signature = computation_types.FunctionType( + expected_new_next_type_signature = federated_language.FunctionType( collections.OrderedDict( - server_state=computation_types.FederatedType( - computation_types.TensorType(np.int64, shape=[None]), - placements.SERVER, + server_state=federated_language.FederatedType( + federated_language.TensorType(np.int64, shape=[None]), + federated_language.SERVER, ), - client_data=computation_types.FederatedType( - np.str_, placements.CLIENTS + client_data=federated_language.FederatedType( + np.str_, federated_language.CLIENTS ), ), - computation_types.FederatedType( - computation_types.TensorType(np.int64, shape=[1]), placements.SERVER + federated_language.FederatedType( + federated_language.TensorType(np.int64, shape=[1]), + federated_language.SERVER, ), ) @@ -417,8 +430,8 @@ def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self): new_param_elements = [old_param_type[0], np.int32, old_param_type[1]] - @federated_computation.federated_computation( - computation_types.StructType(new_param_elements) + @federated_language.federated_computation( + federated_language.StructType(new_param_elements) ) def new_next(param): return iterproc.next([param[0], param[2]]) @@ -426,13 +439,17 @@ def new_next(param): iterproc_with_dataset_as_third_elem = iterative_process.IterativeProcess( iterproc.initialize, new_next ) - expected_new_next_type_signature = computation_types.FunctionType( + expected_new_next_type_signature = federated_language.FunctionType( [ - computation_types.FederatedType(np.int64, placements.SERVER), + federated_language.FederatedType( + np.int64, federated_language.SERVER + ), np.int32, - computation_types.FederatedType(np.str_, placements.CLIENTS), + federated_language.FederatedType( + np.str_, federated_language.CLIENTS + ), ], - computation_types.FederatedType(np.int64, placements.SERVER), + federated_language.FederatedType(np.int64, federated_language.SERVER), ) new_iterproc = iterative_process_compositions.compose_dataset_computation_with_iterative_process( @@ -453,21 +470,26 @@ def test_returns_iterative_process_with_same_non_next_type_signatures(self): def make_zero(): return tf.cast(0, tf.int64) - @federated_computation.federated_computation() + @federated_language.federated_computation() def initialize_fn(): - return intrinsics.federated_eval(make_zero, placements.SERVER) + return federated_language.federated_eval( + make_zero, federated_language.SERVER + ) - @federated_computation.federated_computation(( + @federated_language.federated_computation(( initialize_fn.type_signature.result, - computation_types.FederatedType( - computation_types.SequenceType(np.int64), placements.CLIENTS + federated_language.FederatedType( + federated_language.SequenceType(np.int64), + federated_language.CLIENTS, ), )) def next_fn(server_state, client_data): del client_data return learning_process.LearningProcessOutput( state=server_state, - metrics=intrinsics.federated_value((), placements.SERVER), + metrics=federated_language.federated_value( + (), federated_language.SERVER + ), ) @tensorflow_computation.tf_computation(np.int64) diff --git a/tensorflow_federated/python/simulation/training_loop.py b/tensorflow_federated/python/simulation/training_loop.py index 850edaacee..a80b61b80b 100644 --- a/tensorflow_federated/python/simulation/training_loop.py +++ b/tensorflow_federated/python/simulation/training_loop.py @@ -20,12 +20,10 @@ from typing import Any, Optional from absl import logging +import federated_language from tensorflow_federated.python.common_libs import structure -from tensorflow_federated.python.core.impl.computation import computation_base from tensorflow_federated.python.core.templates import iterative_process -from tensorflow_federated.python.program import program_state_manager as program_state_manager_lib -from tensorflow_federated.python.program import release_manager as release_manager_lib MetricsType = MutableMapping[str, Any] @@ -38,7 +36,7 @@ def _run_training( - training_fn: computation_base.Computation, + training_fn: federated_language.framework.Computation, client_selection_fn: Callable[[int], Any], state: Any, round_num: int, @@ -80,11 +78,11 @@ def run_training_process( evaluation_selection_fn: Optional[Callable[[int], Any]] = None, rounds_per_evaluation: int = 1, program_state_manager: Optional[ - program_state_manager_lib.ProgramStateManager + federated_language.program.ProgramStateManager ] = None, rounds_per_saving_program_state: int = 1, metrics_managers: Optional[ - Iterable[release_manager_lib.ReleaseManager] + Iterable[federated_language.program.ReleaseManager] ] = None, ): """Runs a federated `training_process`. diff --git a/tensorflow_federated/python/simulation/training_loop_test.py b/tensorflow_federated/python/simulation/training_loop_test.py index d471ba6eda..48b9e7a969 100644 --- a/tensorflow_federated/python/simulation/training_loop_test.py +++ b/tensorflow_federated/python/simulation/training_loop_test.py @@ -17,42 +17,44 @@ from absl.testing import absltest from absl.testing import parameterized +import federated_language import numpy as np -from tensorflow_federated.python.core.impl.computation import computation_base -from tensorflow_federated.python.core.impl.federated_context import federated_computation -from tensorflow_federated.python.core.impl.federated_context import intrinsics -from tensorflow_federated.python.core.impl.types import computation_types -from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.simulation import training_loop -@federated_computation.federated_computation +@federated_language.federated_computation def _test_init_fn(): - return intrinsics.federated_value(0, placements.SERVER) + return federated_language.federated_value(0, federated_language.SERVER) -@federated_computation.federated_computation([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), +@federated_language.federated_computation([ + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def _test_next_fn(state, client_data): del state, client_data # Unused - updated_state = intrinsics.federated_value(1, placements.SERVER) + updated_state = federated_language.federated_value( + 1, federated_language.SERVER + ) metrics = collections.OrderedDict([('metric', 1.0)]) - output = intrinsics.federated_value(metrics, placements.SERVER) + output = federated_language.federated_value( + metrics, federated_language.SERVER + ) return updated_state, output -@federated_computation.federated_computation([ - computation_types.FederatedType(np.int32, placements.SERVER), - computation_types.FederatedType(np.int32, placements.CLIENTS), +@federated_language.federated_computation([ + federated_language.FederatedType(np.int32, federated_language.SERVER), + federated_language.FederatedType(np.int32, federated_language.CLIENTS), ]) def _test_evaluation_fn(state, client_data): del state, client_data # Unused metrics = collections.OrderedDict([('metric', 2.0)]) - output = intrinsics.federated_value(metrics, placements.SERVER) + output = federated_language.federated_value( + metrics, federated_language.SERVER + ) return output @@ -145,7 +147,7 @@ def test_evaluation_fns_called(self, total_rounds, rounds_per_evaluation): training_process.next.return_value = ('update', {'metric': 1.0}) training_selection_fn = mock.MagicMock() evaluation_fn = mock.create_autospec( - computation_base.Computation, return_value={'metric': 1.0} + federated_language.framework.Computation, return_value={'metric': 1.0} ) evaluation_selection_fn = mock.MagicMock() evaluation_selection_fn.return_value = [0] diff --git a/third_party/federated_language/BUILD b/third_party/federated_language/BUILD new file mode 100644 index 0000000000..31d4c71147 --- /dev/null +++ b/third_party/federated_language/BUILD @@ -0,0 +1,6 @@ +package( + default_applicable_licenses = ["//:package_license"], + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) diff --git a/third_party/federated_language/proto_library_loads.patch b/third_party/federated_language/proto_library_loads.patch new file mode 100644 index 0000000000..302d7d946f --- /dev/null +++ b/third_party/federated_language/proto_library_loads.patch @@ -0,0 +1,13 @@ +diff --git federated_language/proto/BUILD federated_language/proto/BUILD +index d23d7cf..004dcb2 100644 +--- federated_language/proto/BUILD ++++ federated_language/proto/BUILD +@@ -1,6 +1,5 @@ +-load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +-load("@protobuf//bazel:proto_library.bzl", "proto_library") +-load("@protobuf//bazel:py_proto_library.bzl", "py_proto_library") ++load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library") ++load("@rules_cc//cc:defs.bzl", "cc_proto_library") + load("@rules_python//python:defs.bzl", "py_library") + + package( diff --git a/third_party/federated_language/python_deps.patch b/third_party/federated_language/python_deps.patch new file mode 100644 index 0000000000..e7fe4f91a1 --- /dev/null +++ b/third_party/federated_language/python_deps.patch @@ -0,0 +1,687 @@ +diff --git federated_language/common_libs/BUILD federated_language/common_libs/BUILD +index 10806a0..730c71e 100644 +--- federated_language/common_libs/BUILD ++++ federated_language/common_libs/BUILD +@@ -23,7 +23,6 @@ py_library( + name = "golden", + testonly = True, + srcs = ["golden.py"], +- deps = ["@pypi//absl_py"], + ) + + py_test( +@@ -43,29 +42,19 @@ py_test( + "golden_test_goldens/test_check_string_succeeds.expected", + "golden_test_goldens/test_check_string_updates.expected", + ], +- deps = [ +- ":golden", +- "@pypi//absl_py", +- ], ++ deps = [":golden"], + ) + + py_library( + name = "py_typecheck", + srcs = ["py_typecheck.py"], +- deps = [ +- "@pypi//attrs", +- "@pypi//typing_extensions", +- ], + ) + + py_test( + name = "py_typecheck_test", + size = "small", + srcs = ["py_typecheck_test.py"], +- deps = [ +- ":py_typecheck", +- "@pypi//absl_py", +- ], ++ deps = [":py_typecheck"], + ) + + py_library( +@@ -80,10 +69,7 @@ py_test( + name = "retrying_test", + size = "small", + srcs = ["retrying_test.py"], +- deps = [ +- ":retrying", +- "@pypi//absl_py", +- ], ++ deps = [":retrying"], + ) + + py_library( +@@ -94,39 +80,25 @@ py_library( + py_library( + name = "structure", + srcs = ["structure.py"], +- deps = [ +- ":py_typecheck", +- "@pypi//attrs", +- "@pypi//dm_tree", +- ], ++ deps = [":py_typecheck"], + ) + + py_test( + name = "structure_test", + size = "small", + srcs = ["structure_test.py"], +- deps = [ +- ":structure", +- "@pypi//absl_py", +- "@pypi//attrs", +- ], ++ deps = [":structure"], + ) + + py_library( + name = "tracing", + srcs = ["tracing.py"], +- deps = [ +- ":py_typecheck", +- "@pypi//absl_py", +- ], ++ deps = [":py_typecheck"], + ) + + py_test( + name = "tracing_test", + size = "small", + srcs = ["tracing_test.py"], +- deps = [ +- ":tracing", +- "@pypi//absl_py", +- ], ++ deps = [":tracing"], + ) +diff --git federated_language/compiler/BUILD federated_language/compiler/BUILD +index b5b594a..a8a07eb 100644 +--- federated_language/compiler/BUILD ++++ federated_language/compiler/BUILD +@@ -29,8 +29,6 @@ py_library( + "//federated_language/proto:array_py_pb2", + "//federated_language/types:array_shape", + "//federated_language/types:dtype_utils", +- "@pypi//ml_dtypes", +- "@pypi//numpy", + ], + ) + +@@ -41,9 +39,6 @@ py_test( + ":array", + "//federated_language/proto:array_py_pb2", + "//federated_language/proto:data_type_py_pb2", +- "@pypi//absl_py", +- "@pypi//ml_dtypes", +- "@pypi//numpy", + ], + ) + +@@ -123,8 +118,6 @@ py_test( + "//federated_language/types:placements", + "//federated_language/types:type_analysis", + "//federated_language/types:type_test_utils", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -139,7 +132,6 @@ py_library( + "//federated_language/types:computation_types", + "//federated_language/types:placements", + "@protobuf//:protobuf_python", +- "@pypi//numpy", + ], + ) + +@@ -158,7 +150,6 @@ py_library( + "//federated_language/types:type_serialization", + "//federated_language/types:typed_object", + "@protobuf//:protobuf_python", +- "@pypi//numpy", + ], + ) + +@@ -180,10 +171,6 @@ py_test( + "//federated_language/types:placements", + "//federated_language/types:type_serialization", + "@protobuf//:protobuf_python", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//ml_dtypes", +- "@pypi//numpy", + ], + ) + +@@ -207,8 +194,6 @@ py_test( + "//federated_language/types:computation_types", + "//federated_language/types:type_factory", + "//federated_language/types:type_serialization", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -220,7 +205,6 @@ py_library( + "//federated_language/types:computation_types", + "//federated_language/types:placements", + "//federated_language/types:type_factory", +- "@pypi//numpy", + ], + ) + +@@ -228,10 +212,7 @@ py_test( + name = "intrinsic_defs_test", + size = "small", + srcs = ["intrinsic_defs_test.py"], +- deps = [ +- ":intrinsic_defs", +- "@pypi//absl_py", +- ], ++ deps = [":intrinsic_defs"], + ) + + py_library( +@@ -257,8 +238,6 @@ py_test( + "//federated_language/proto:computation_py_pb2", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -288,7 +267,5 @@ py_test( + ":tree_analysis", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) +diff --git federated_language/computation/BUILD federated_language/computation/BUILD +index 0c6151f..58f930f 100644 +--- federated_language/computation/BUILD ++++ federated_language/computation/BUILD +@@ -58,8 +58,6 @@ py_test( + "//federated_language/types:computation_types", + "//federated_language/types:type_serialization", + "//federated_language/types:type_test_utils", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -86,8 +84,6 @@ py_test( + "//federated_language/compiler:computation_factory", + "//federated_language/context_stack:context_stack_impl", + "//federated_language/types:computation_types", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -117,8 +113,6 @@ py_test( + "//federated_language/proto:computation_py_pb2", + "//federated_language/types:computation_types", + "//federated_language/types:type_serialization", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -143,8 +137,6 @@ py_test( + ":function_utils", + "//federated_language/common_libs:structure", + "//federated_language/types:computation_types", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -170,7 +162,5 @@ py_test( + "//federated_language/types:computation_types", + "//federated_language/types:type_conversions", + "//federated_language/types:type_serialization", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) +diff --git federated_language/context_stack/BUILD federated_language/context_stack/BUILD +index 7bf1405..1b8a63e 100644 +--- federated_language/context_stack/BUILD ++++ federated_language/context_stack/BUILD +@@ -49,7 +49,6 @@ py_test( + deps = [ + ":context_stack_impl", + ":context_stack_test_utils", +- "@pypi//absl_py", + ], + ) + +@@ -59,7 +58,6 @@ py_library( + deps = [ + ":context_base", + ":context_stack_impl", +- "@pypi//absl_py", + ], + ) + +@@ -69,7 +67,6 @@ py_test( + deps = [ + ":context_stack_impl", + ":context_stack_test_utils", +- "@pypi//absl_py", + ], + ) + +@@ -86,7 +83,6 @@ py_test( + deps = [ + ":context_stack_impl", + ":get_context_stack", +- "@pypi//absl_py", + ], + ) + +@@ -113,7 +109,6 @@ py_test( + ":context_stack_impl", + ":context_stack_test_utils", + ":set_default_context", +- "@pypi//absl_py", + ], + ) + +diff --git federated_language/execution_contexts/BUILD federated_language/execution_contexts/BUILD +index 9b5916d..0a90a09 100644 +--- federated_language/execution_contexts/BUILD ++++ federated_language/execution_contexts/BUILD +@@ -37,7 +37,6 @@ py_library( + "//federated_language/types:computation_types", + "//federated_language/types:type_conversions", + "//federated_language/types:typed_object", +- "@pypi//dm_tree", + ], + ) + +@@ -48,7 +47,6 @@ py_test( + deps = [ + ":async_execution_context", + "//federated_language/executors:executors_errors", +- "@pypi//absl_py", + ], + ) + +@@ -68,7 +66,6 @@ py_test( + deps = [ + ":compiler_pipeline", + "//federated_language/computation:computation_base", +- "@pypi//absl_py", + ], + ) + +diff --git federated_language/executors/BUILD federated_language/executors/BUILD +index 65fea33..2dcaa45 100644 +--- federated_language/executors/BUILD ++++ federated_language/executors/BUILD +@@ -37,8 +37,6 @@ py_test( + "//federated_language/common_libs:structure", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +diff --git federated_language/federated_context/BUILD federated_language/federated_context/BUILD +index ace1073..f6c1975 100644 +--- federated_language/federated_context/BUILD ++++ federated_language/federated_context/BUILD +@@ -44,8 +44,6 @@ py_test( + "//federated_language/context_stack:get_context_stack", + "//federated_language/context_stack:runtime_error_context", + "//federated_language/types:computation_types", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -77,8 +75,6 @@ py_test( + "//federated_language/context_stack:context_stack_impl", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -107,8 +103,6 @@ py_test( + "//federated_language/computation:function_utils", + "//federated_language/context_stack:context_stack_impl", + "//federated_language/types:computation_types", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -131,7 +125,6 @@ py_library( + "//federated_language/types:placements", + "//federated_language/types:type_analysis", + "//federated_language/types:type_factory", +- "@pypi//numpy", + ], + ) + +@@ -148,8 +141,6 @@ py_test( + "//federated_language/context_stack:context_stack_test_utils", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -172,7 +163,6 @@ py_library( + "//federated_language/types:placements", + "//federated_language/types:type_conversions", + "//federated_language/types:typed_object", +- "@pypi//attrs", + ], + ) + +@@ -190,9 +180,6 @@ py_test( + "//federated_language/context_stack:context_stack_impl", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//attrs", +- "@pypi//numpy", + ], + ) + +@@ -222,7 +209,5 @@ py_test( + "//federated_language/context_stack:context_stack_impl", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) +diff --git federated_language/program/BUILD federated_language/program/BUILD +index 6e0df6d..f7a54a0 100644 +--- federated_language/program/BUILD ++++ federated_language/program/BUILD +@@ -58,8 +58,6 @@ py_test( + "//federated_language/context_stack:context_stack_impl", + "//federated_language/types:computation_types", + "//federated_language/types:placements", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -69,7 +67,6 @@ py_library( + deps = [ + ":release_manager", + ":value_reference", +- "@pypi//absl_py", + ], + ) + +@@ -79,9 +76,6 @@ py_test( + deps = [ + ":logging_release_manager", + ":program_test_utils", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -100,9 +94,6 @@ py_test( + deps = [ + ":memory_release_manager", + ":program_test_utils", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -119,10 +110,7 @@ py_library( + py_test( + name = "program_state_manager_test", + srcs = ["program_state_manager_test.py"], +- deps = [ +- ":program_state_manager", +- "@pypi//absl_py", +- ], ++ deps = [":program_state_manager"], + ) + + py_library( +@@ -134,9 +122,6 @@ py_library( + "//federated_language/common_libs:py_typecheck", + "//federated_language/common_libs:serializable", + "//federated_language/types:computation_types", +- "@pypi//attrs", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -147,8 +132,6 @@ py_library( + ":structure_utils", + ":value_reference", + "//federated_language/common_libs:py_typecheck", +- "@pypi//attrs", +- "@pypi//dm_tree", + ], + ) + +@@ -158,9 +141,6 @@ py_test( + deps = [ + ":program_test_utils", + ":release_manager", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -182,19 +162,13 @@ py_test( + ":program_test_utils", + ":serialization_utils", + "//federated_language/types:computation_types", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + + py_library( + name = "structure_utils", + srcs = ["structure_utils.py"], +- deps = [ +- "//federated_language/common_libs:py_typecheck", +- "@pypi//attrs", +- "@pypi//dm_tree", +- ], ++ deps = ["//federated_language/common_libs:py_typecheck"], + ) + + py_test( +@@ -203,9 +177,6 @@ py_test( + deps = [ + ":program_test_utils", + ":structure_utils", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -216,7 +187,6 @@ py_library( + ":structure_utils", + "//federated_language/types:computation_types", + "//federated_language/types:typed_object", +- "@pypi//numpy", + ], + ) + +@@ -226,8 +196,5 @@ py_test( + deps = [ + ":program_test_utils", + ":value_reference", +- "@pypi//absl_py", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) +diff --git federated_language/test/BUILD federated_language/test/BUILD +index ab52bca..39593e0 100644 +--- federated_language/test/BUILD ++++ federated_language/test/BUILD +@@ -33,6 +33,5 @@ py_test( + "//federated_language/federated_context:federated_computation", + "//federated_language/federated_context:intrinsics", + "//federated_language/types:placements", +- "@pypi//absl_py", + ], + ) +diff --git federated_language/types/BUILD federated_language/types/BUILD +index e209fb2..e102338 100644 +--- federated_language/types/BUILD ++++ federated_language/types/BUILD +@@ -39,7 +39,6 @@ py_test( + deps = [ + ":array_shape", + "//federated_language/proto:array_py_pb2", +- "@pypi//absl_py", + ], + ) + +@@ -52,9 +51,6 @@ py_library( + ":placements", + "//federated_language/common_libs:py_typecheck", + "//federated_language/common_libs:structure", +- "@pypi//attrs", +- "@pypi//numpy", +- "@pypi//typing_extensions", + ], + ) + +@@ -80,31 +76,19 @@ py_test( + ":placements", + "//federated_language/common_libs:golden", + "//federated_language/common_libs:structure", +- "@pypi//absl_py", +- "@pypi//attrs", +- "@pypi//numpy", + ], + ) + + py_library( + name = "dtype_utils", + srcs = ["dtype_utils.py"], +- deps = [ +- "//federated_language/proto:data_type_py_pb2", +- "@pypi//ml_dtypes", +- "@pypi//numpy", +- ], ++ deps = ["//federated_language/proto:data_type_py_pb2"], + ) + + py_test( + name = "dtype_utils_test", + srcs = ["dtype_utils_test.py"], +- deps = [ +- ":dtype_utils", +- "@pypi//absl_py", +- "@pypi//ml_dtypes", +- "@pypi//numpy", +- ], ++ deps = [":dtype_utils"], + ) + + py_library( +@@ -116,10 +100,7 @@ py_test( + name = "placements_test", + size = "small", + srcs = ["placements_test.py"], +- deps = [ +- ":placements", +- "@pypi//absl_py", +- ], ++ deps = [":placements"], + ) + + py_library( +@@ -133,8 +114,6 @@ py_library( + ":type_transformations", + "//federated_language/common_libs:py_typecheck", + "//federated_language/common_libs:structure", +- "@pypi//ml_dtypes", +- "@pypi//numpy", + ], + ) + +@@ -147,9 +126,6 @@ py_test( + ":placements", + ":type_analysis", + "//federated_language/common_libs:structure", +- "@pypi//absl_py", +- "@pypi//ml_dtypes", +- "@pypi//numpy", + ], + ) + +@@ -162,9 +138,6 @@ py_library( + ":typed_object", + "//federated_language/common_libs:py_typecheck", + "//federated_language/common_libs:structure", +- "@pypi//attrs", +- "@pypi//dm_tree", +- "@pypi//numpy", + ], + ) + +@@ -178,9 +151,6 @@ py_test( + ":type_conversions", + ":typed_object", + "//federated_language/common_libs:structure", +- "@pypi//absl_py", +- "@pypi//attrs", +- "@pypi//numpy", + ], + ) + +@@ -197,8 +167,6 @@ py_test( + deps = [ + ":computation_types", + ":type_factory", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -227,8 +195,6 @@ py_test( + ":type_serialization", + "//federated_language/proto:computation_py_pb2", + "//federated_language/proto:data_type_py_pb2", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + +@@ -255,8 +221,6 @@ py_test( + ":computation_types", + ":placements", + ":type_transformations", +- "@pypi//absl_py", +- "@pypi//numpy", + ], + ) + diff --git a/third_party/federated_language/structure_visibility.patch b/third_party/federated_language/structure_visibility.patch new file mode 100644 index 0000000000..d608b6fd3d --- /dev/null +++ b/third_party/federated_language/structure_visibility.patch @@ -0,0 +1,12 @@ +diff --git federated_language/common_libs/BUILD federated_language/common_libs/BUILD +index 730c71e..6e584cf 100644 +--- federated_language/common_libs/BUILD ++++ federated_language/common_libs/BUILD +@@ -80,6 +80,7 @@ py_library( + py_library( + name = "structure", + srcs = ["structure.py"], ++ visibility = ["//visibility:public"], + deps = [":py_typecheck"], + ) + diff --git a/tools/python_package/BUILD b/tools/python_package/BUILD index 784754310c..4cafc6c6ad 100644 --- a/tools/python_package/BUILD +++ b/tools/python_package/BUILD @@ -36,12 +36,9 @@ sh_binary( "//tensorflow_federated/python/core/environments/xla_backend", "//tensorflow_federated/python/core/impl", "//tensorflow_federated/python/core/impl/compiler", - "//tensorflow_federated/python/core/impl/computation", - "//tensorflow_federated/python/core/impl/context_stack", "//tensorflow_federated/python/core/impl/execution_contexts", "//tensorflow_federated/python/core/impl/executor_stacks", "//tensorflow_federated/python/core/impl/executors", - "//tensorflow_federated/python/core/impl/federated_context", ], )