Skip to content

Commit

Permalink
Create computation_factory module to reduce duplicate lambda constr…
Browse files Browse the repository at this point in the history
…uction functions.

PiperOrigin-RevId: 309130728
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Apr 30, 2020
1 parent 4bf48d0 commit bf2224b
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 62 deletions.
50 changes: 38 additions & 12 deletions tensorflow_federated/python/core/impl/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,33 +135,27 @@ py_test(
)

py_library(
name = "tensorflow_computation_factory",
srcs = ["tensorflow_computation_factory.py"],
name = "computation_factory",
srcs = ["computation_factory.py"],
srcs_version = "PY3",
deps = [
":type_factory",
":type_serialization",
":type_transformations",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:anonymous_tuple",
"//tensorflow_federated/python/common_libs:serialization_utils",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/impl:type_utils",
"//tensorflow_federated/python/core/impl/utils:tensorflow_utils",
],
)

py_test(
name = "tensorflow_computation_factory_test",
srcs = ["tensorflow_computation_factory_test.py"],
name = "computation_factory_test",
srcs = ["computation_factory_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":tensorflow_computation_factory",
":test_utils",
":computation_factory",
":type_factory",
":type_serialization",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:anonymous_tuple",
"//tensorflow_federated/python/core/api:computation_types",
],
)
Expand Down Expand Up @@ -236,6 +230,38 @@ py_test(
],
)

py_library(
name = "tensorflow_computation_factory",
srcs = ["tensorflow_computation_factory.py"],
srcs_version = "PY3",
deps = [
":type_serialization",
":type_transformations",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:anonymous_tuple",
"//tensorflow_federated/python/common_libs:serialization_utils",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/impl:type_utils",
"//tensorflow_federated/python/core/impl/utils:tensorflow_utils",
],
)

py_test(
name = "tensorflow_computation_factory_test",
srcs = ["tensorflow_computation_factory_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":tensorflow_computation_factory",
":test_utils",
":type_factory",
":type_serialization",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:anonymous_tuple",
"//tensorflow_federated/python/core/api:computation_types",
],
)

py_library(
name = "test_utils",
testonly = True,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Lint as: python3
# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library implementing common `pb.Computation` structures."""

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.impl.compiler import type_factory
from tensorflow_federated.python.core.impl.compiler import type_serialization


def create_lambda_empty_tuple() -> pb.Computation:
"""Returns a lambda computation returning an empty tuple.
Has the type signature:
( -> <>)
Returns:
An instance of `pb.Computation`.
"""
result_type = computation_types.NamedTupleType([])
type_signature = computation_types.FunctionType(None, result_type)
result = pb.Computation(
type=type_serialization.serialize_type(result_type),
tuple=pb.Tuple(element=[]))
fn = pb.Lambda(parameter_name=None, result=result)
# We are unpacking the lambda argument here because `lambda` is a reserved
# keyword in Python, but it is also the name of the parameter for a
# `pb.Computation`.
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
return pb.Computation(
type=type_serialization.serialize_type(type_signature), **{'lambda': fn}) # pytype: disable=wrong-keyword-args


def create_lambda_identity(type_spec) -> pb.Computation:
"""Returns a lambda computation representing an identity function.
Has the type signature:
(T -> T)
Args:
type_spec: A type convertible to instance of `computation_types.Type` via
`computation_types.to_type`.
Returns:
An instance of `pb.Computation`.
"""
type_spec = computation_types.to_type(type_spec)
type_signature = type_factory.unary_op(type_spec)
result = pb.Computation(
type=type_serialization.serialize_type(type_spec),
reference=pb.Reference(name='a'))
fn = pb.Lambda(parameter_name='a', result=result)
# We are unpacking the lambda argument here because `lambda` is a reserved
# keyword in Python, but it is also the name of the parameter for a
# `pb.Computation`.
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
return pb.Computation(
type=type_serialization.serialize_type(type_signature), **{'lambda': fn}) # pytype: disable=wrong-keyword-args
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Lint as: python3
# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import tensorflow as tf

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.impl.compiler import computation_factory
from tensorflow_federated.python.core.impl.compiler import type_factory
from tensorflow_federated.python.core.impl.compiler import type_serialization


class CreateLambdaEmptyTupleTest(absltest.TestCase):

def test_returns_coputation(self):
proto = computation_factory.create_lambda_empty_tuple()

self.assertIsInstance(proto, pb.Computation)
actual_type = type_serialization.deserialize_type(proto.type)
expected_type = computation_types.FunctionType(None, [])
self.assertEqual(actual_type, expected_type)


class CreateLambdaIdentityTest(absltest.TestCase):

def test_returns_computation_int(self):
type_signature = computation_types.TensorType(tf.int32)

proto = computation_factory.create_lambda_identity(type_signature)

self.assertIsInstance(proto, pb.Computation)
actual_type = type_serialization.deserialize_type(proto.type)
expected_type = type_factory.unary_op(type_signature)
self.assertEqual(actual_type, expected_type)

def test_returns_computation_tuple_unnamed(self):
type_signature = computation_types.NamedTupleType([tf.int32, tf.float32])

proto = computation_factory.create_lambda_identity(type_signature)

self.assertIsInstance(proto, pb.Computation)
actual_type = type_serialization.deserialize_type(proto.type)
expected_type = type_factory.unary_op(type_signature)
self.assertEqual(actual_type, expected_type)

def test_returns_computation_tuple_named(self):
type_signature = computation_types.NamedTupleType([('a', tf.int32),
('b', tf.float32)])

proto = computation_factory.create_lambda_identity(type_signature)

self.assertIsInstance(proto, pb.Computation)
actual_type = type_serialization.deserialize_type(proto.type)
expected_type = type_factory.unary_op(type_signature)
self.assertEqual(actual_type, expected_type)

def test_returns_computation_sequence(self):
type_signature = computation_types.SequenceType(tf.int32)

proto = computation_factory.create_lambda_identity(type_signature)

self.assertIsInstance(proto, pb.Computation)
actual_type = type_serialization.deserialize_type(proto.type)
expected_type = type_factory.unary_op(type_signature)
self.assertEqual(actual_type, expected_type)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tensorflow_federated.python.core.impl.utils import tensorflow_utils


def create_constant(scalar_value, type_spec):
def create_constant(scalar_value, type_spec) -> pb.Computation:
"""Returns a tensorflow computation returning a constant `scalar_value`.
Has the type signature:
Expand Down Expand Up @@ -108,7 +108,7 @@ def _create_result_tensor(type_spec, scalar_value):
tensorflow=tensorflow)


def create_empty_tuple():
def create_empty_tuple() -> pb.Computation:
"""Returns a tensorflow computation returning an empty tuple.
Has the type signature:
Expand All @@ -133,7 +133,7 @@ def create_empty_tuple():
tensorflow=tensorflow)


def create_identity(type_spec):
def create_identity(type_spec) -> pb.Computation:
"""Returns a tensorflow computation representing an identity function.
Has the type signature:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_returns_computation_tuple_nested(self):

def test_raises_type_error_with_non_scalar_value(self):
value = np.zeros([1])
type_signature = tf.int32
type_signature = computation_types.TensorType(tf.int32)

with self.assertRaises(TypeError):
tensorflow_computation_factory.create_constant(value, type_signature)
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_returns_coputation(self):
class CreateIdentityTest(parameterized.TestCase):

def test_returns_computation_int(self):
type_signature = tf.int32
type_signature = computation_types.TensorType(tf.int32)

proto = tensorflow_computation_factory.create_identity(type_signature)

Expand All @@ -161,7 +161,7 @@ def test_returns_computation_int(self):
self.assertEqual(actual_value, expected_value)

def test_returns_computation_tuple_unnamed(self):
type_signature = [tf.int32, tf.float32]
type_signature = computation_types.NamedTupleType([tf.int32, tf.float32])

proto = tensorflow_computation_factory.create_identity(type_signature)

Expand All @@ -174,7 +174,8 @@ def test_returns_computation_tuple_unnamed(self):
self.assertEqual(actual_value, expected_value)

def test_returns_computation_tuple_named(self):
type_signature = [('a', tf.int32), ('b', tf.float32)]
type_signature = computation_types.NamedTupleType([('a', tf.int32),
('b', tf.float32)])

proto = tensorflow_computation_factory.create_identity(type_signature)

Expand Down
3 changes: 2 additions & 1 deletion tensorflow_federated/python/core/impl/executors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ py_library(
"//tensorflow_federated/python/common_libs:tracing",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/impl:type_utils",
"//tensorflow_federated/python/core/impl/compiler:computation_factory",
"//tensorflow_federated/python/core/impl/compiler:intrinsic_defs",
"//tensorflow_federated/python/core/impl/compiler:placement_literals",
"//tensorflow_federated/python/core/impl/compiler:type_factory",
"//tensorflow_federated/python/core/impl/compiler:type_serialization",
],
)

Expand Down Expand Up @@ -366,6 +366,7 @@ py_library(
"//tensorflow_federated/python/common_libs:serialization_utils",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/impl:reference_executor",
"//tensorflow_federated/python/core/impl/compiler:computation_factory",
"//tensorflow_federated/python/core/impl/compiler:intrinsic_defs",
"//tensorflow_federated/python/core/impl/compiler:placement_literals",
"//tensorflow_federated/python/core/impl/compiler:tensorflow_computation_factory",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from tensorflow_federated.python.common_libs import tracing
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.impl import type_utils
from tensorflow_federated.python.core.impl.compiler import computation_factory
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
from tensorflow_federated.python.core.impl.compiler import placement_literals
from tensorflow_federated.python.core.impl.compiler import type_factory
from tensorflow_federated.python.core.impl.compiler import type_serialization
from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_utils
from tensorflow_federated.python.core.impl.executors import executor_value_base
Expand Down Expand Up @@ -117,22 +117,6 @@ async def _compute_tuple(anon_tuple):
return await _compute_tuple(self._value)


def _create_lambda_identity_comp(type_spec):
"""Returns a `pb.Computation` representing an identity function."""
py_typecheck.check_type(type_spec, computation_types.Type)
type_signature = type_serialization.serialize_type(
type_factory.unary_op(type_spec))
result = pb.Computation(
type=type_serialization.serialize_type(type_spec),
reference=pb.Reference(name='x'))
fn = pb.Lambda(parameter_name='x', result=result)
# We are unpacking the lambda argument here because `lambda` is a reserved
# keyword in Python, but it is also the name of the parameter for a
# `pb.Computation`.
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#keyword-conflicts
return pb.Computation(type=type_signature, **{'lambda': fn}) # pytype: disable=wrong-keyword-args


class ComposingExecutor(executor_base.Executor):
"""An executor composed of subordinate executors that manage disjoint scopes.
Expand Down Expand Up @@ -382,7 +366,7 @@ async def _compute_intrinsic_federated_aggregate(self, arg):
val = arg.internal_representation[0]
py_typecheck.check_type(val, list)
py_typecheck.check_len(val, len(self._child_executors))
identity_report = _create_lambda_identity_comp(zero_type)
identity_report = computation_factory.create_lambda_identity(zero_type)
identity_report_type = type_factory.unary_op(zero_type)
aggr_type = computation_types.FunctionType(
computation_types.NamedTupleType([
Expand Down Expand Up @@ -562,7 +546,8 @@ async def _compute_intrinsic_federated_sum(self, arg):
executor_utils.embed_tf_binary_operator(self, arg.type_signature.member,
tf.add),
self.create_value(
_create_lambda_identity_comp(arg.type_signature.member),
computation_factory.create_lambda_identity(
arg.type_signature.member),
type_factory.unary_op(arg.type_signature.member))
]))
aggregate_args = await self.create_tuple([arg, zero, plus, plus, identity])
Expand Down
Loading

0 comments on commit bf2224b

Please sign in to comment.