forked from google-parfait/tensorflow-federated
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create
computation_factory
module to reduce duplicate lambda constr…
…uction functions. PiperOrigin-RevId: 309130728
- Loading branch information
1 parent
4bf48d0
commit bf2224b
Showing
8 changed files
with
212 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
tensorflow_federated/python/core/impl/compiler/computation_factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Lint as: python3 | ||
# Copyright 2020, The TensorFlow Federated Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Library implementing common `pb.Computation` structures.""" | ||
|
||
from tensorflow_federated.proto.v0 import computation_pb2 as pb | ||
from tensorflow_federated.python.core.api import computation_types | ||
from tensorflow_federated.python.core.impl.compiler import type_factory | ||
from tensorflow_federated.python.core.impl.compiler import type_serialization | ||
|
||
|
||
def create_lambda_empty_tuple() -> pb.Computation: | ||
"""Returns a lambda computation returning an empty tuple. | ||
Has the type signature: | ||
( -> <>) | ||
Returns: | ||
An instance of `pb.Computation`. | ||
""" | ||
result_type = computation_types.NamedTupleType([]) | ||
type_signature = computation_types.FunctionType(None, result_type) | ||
result = pb.Computation( | ||
type=type_serialization.serialize_type(result_type), | ||
tuple=pb.Tuple(element=[])) | ||
fn = pb.Lambda(parameter_name=None, result=result) | ||
# We are unpacking the lambda argument here because `lambda` is a reserved | ||
# keyword in Python, but it is also the name of the parameter for a | ||
# `pb.Computation`. | ||
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts | ||
return pb.Computation( | ||
type=type_serialization.serialize_type(type_signature), **{'lambda': fn}) # pytype: disable=wrong-keyword-args | ||
|
||
|
||
def create_lambda_identity(type_spec) -> pb.Computation: | ||
"""Returns a lambda computation representing an identity function. | ||
Has the type signature: | ||
(T -> T) | ||
Args: | ||
type_spec: A type convertible to instance of `computation_types.Type` via | ||
`computation_types.to_type`. | ||
Returns: | ||
An instance of `pb.Computation`. | ||
""" | ||
type_spec = computation_types.to_type(type_spec) | ||
type_signature = type_factory.unary_op(type_spec) | ||
result = pb.Computation( | ||
type=type_serialization.serialize_type(type_spec), | ||
reference=pb.Reference(name='a')) | ||
fn = pb.Lambda(parameter_name='a', result=result) | ||
# We are unpacking the lambda argument here because `lambda` is a reserved | ||
# keyword in Python, but it is also the name of the parameter for a | ||
# `pb.Computation`. | ||
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts | ||
return pb.Computation( | ||
type=type_serialization.serialize_type(type_signature), **{'lambda': fn}) # pytype: disable=wrong-keyword-args |
82 changes: 82 additions & 0 deletions
82
tensorflow_federated/python/core/impl/compiler/computation_factory_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Lint as: python3 | ||
# Copyright 2020, The TensorFlow Federated Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from absl.testing import absltest | ||
import tensorflow as tf | ||
|
||
from tensorflow_federated.proto.v0 import computation_pb2 as pb | ||
from tensorflow_federated.python.core.api import computation_types | ||
from tensorflow_federated.python.core.impl.compiler import computation_factory | ||
from tensorflow_federated.python.core.impl.compiler import type_factory | ||
from tensorflow_federated.python.core.impl.compiler import type_serialization | ||
|
||
|
||
class CreateLambdaEmptyTupleTest(absltest.TestCase): | ||
|
||
def test_returns_coputation(self): | ||
proto = computation_factory.create_lambda_empty_tuple() | ||
|
||
self.assertIsInstance(proto, pb.Computation) | ||
actual_type = type_serialization.deserialize_type(proto.type) | ||
expected_type = computation_types.FunctionType(None, []) | ||
self.assertEqual(actual_type, expected_type) | ||
|
||
|
||
class CreateLambdaIdentityTest(absltest.TestCase): | ||
|
||
def test_returns_computation_int(self): | ||
type_signature = computation_types.TensorType(tf.int32) | ||
|
||
proto = computation_factory.create_lambda_identity(type_signature) | ||
|
||
self.assertIsInstance(proto, pb.Computation) | ||
actual_type = type_serialization.deserialize_type(proto.type) | ||
expected_type = type_factory.unary_op(type_signature) | ||
self.assertEqual(actual_type, expected_type) | ||
|
||
def test_returns_computation_tuple_unnamed(self): | ||
type_signature = computation_types.NamedTupleType([tf.int32, tf.float32]) | ||
|
||
proto = computation_factory.create_lambda_identity(type_signature) | ||
|
||
self.assertIsInstance(proto, pb.Computation) | ||
actual_type = type_serialization.deserialize_type(proto.type) | ||
expected_type = type_factory.unary_op(type_signature) | ||
self.assertEqual(actual_type, expected_type) | ||
|
||
def test_returns_computation_tuple_named(self): | ||
type_signature = computation_types.NamedTupleType([('a', tf.int32), | ||
('b', tf.float32)]) | ||
|
||
proto = computation_factory.create_lambda_identity(type_signature) | ||
|
||
self.assertIsInstance(proto, pb.Computation) | ||
actual_type = type_serialization.deserialize_type(proto.type) | ||
expected_type = type_factory.unary_op(type_signature) | ||
self.assertEqual(actual_type, expected_type) | ||
|
||
def test_returns_computation_sequence(self): | ||
type_signature = computation_types.SequenceType(tf.int32) | ||
|
||
proto = computation_factory.create_lambda_identity(type_signature) | ||
|
||
self.assertIsInstance(proto, pb.Computation) | ||
actual_type = type_serialization.deserialize_type(proto.type) | ||
expected_type = type_factory.unary_op(type_signature) | ||
self.assertEqual(actual_type, expected_type) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.