Skip to content

Commit

Permalink
Cleanup old Dataset iterator usage.
Browse files Browse the repository at this point in the history
Necessary for TF 2.0 compatibility.

PiperOrigin-RevId: 258869411
  • Loading branch information
ZacharyGarrett authored and tensorflower-gardener committed Jul 18, 2019
1 parent 46504db commit 2760e50
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 135 deletions.
1 change: 1 addition & 0 deletions tensorflow_federated/proto/v0/computation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ message TensorFlow {
// A representation of a sequence declared in the type signature.
message SequenceBinding {
oneof binding {
// WARNING: `iterator_string_handle_name` IS NO LONGER SUPPORTED.
// The name of the placeholder tensor that represents the string handle
// of the data set iterator associated with the sequence.
string iterator_string_handle_name = 1 [deprecated = true];
Expand Down
1 change: 0 additions & 1 deletion tensorflow_federated/python/core/impl/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,6 @@ py_library(
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:serialization_utils",
"//tensorflow_federated/python/core/api:computation_types",
"//tensorflow_federated/python/core/api:typed_object",
"//tensorflow_federated/python/tensorflow_libs:graph_merge",
],
)
Expand Down
99 changes: 2 additions & 97 deletions tensorflow_federated/python/core/impl/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import collections
import functools
import itertools
import logging

import attr
import numpy as np
Expand All @@ -35,7 +34,6 @@
from tensorflow_federated.python.common_libs import anonymous_tuple
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import typed_object
from tensorflow_federated.python.core.impl import dtype_utils
from tensorflow_federated.python.core.impl import function_utils
from tensorflow_federated.python.core.impl import type_utils
Expand Down Expand Up @@ -357,88 +355,9 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
'graph.'.format(repr(parameter_type)))


class OneShotDataset(typed_object.TypedObject):
"""A factory of `tf.data.Dataset`-like objects based on a no-argument lambda.
This factory supports the same methods as the data sets constructed by the
lambda. Upon invocation, it constructs a new data set by invoking the lambda,
then forwards the call to that data set. A new data set is created per call.
"""

# TODO(b/129956296): Eventually delete this deprecated class.

def __init__(self, fn, element_type):
"""Constructs this factory from `fn`.
Args:
fn: A no-argument callable that creates instances of `tf.data.Dataset`.
element_type: The type of elements.
"""
# TODO(b/131426323) Possibly reuse TensorFlow's @deprecation.deprecated()
# here if possible.
logging.warning('OneShotDataset is deprecated.')
py_typecheck.check_type(element_type, computation_types.Type)
self._type_signature = computation_types.SequenceType(element_type)
self._fn = fn

@property
def type_signature(self):
"""Returns the TFF type of this object (an instance of `tff.Type`)."""
return self._type_signature

def __getattr__(self, name):
return getattr(self._fn(), name)


# TODO(b/129956296): Eventually delete this deprecated declaration.
DATASET_REPRESENTATION_TYPES = (tf.data.Dataset, tf.compat.v1.data.Dataset,
tf.compat.v2.data.Dataset, OneShotDataset)


def make_dataset_from_string_handle(handle, type_spec):
"""Constructs a `tf.data.Dataset` from a string handle tensor and type spec.
Args:
handle: The tensor that represents the string handle.
type_spec: The type spec of elements of the data set, either an instance of
`types.Type` or something convertible to it.
Returns:
A corresponding instance of `tf.data.Dataset`.
"""
# TODO(b/129956296): Eventually delete this deprecated code path.

type_spec = computation_types.to_type(type_spec)
tf_dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(type_spec)

def make(handle=handle, tf_dtypes=tf_dtypes, shapes=shapes):
"""An embedded no-argument function that constructs the data set on-demand.
This is invoked by `OneShotDataset` on each access to the data set argument
passed to the body of the TF computation to ensure that the iterators and
tje map are constructed in the appropriate context (e.g., in a defun).
Args:
handle: Captured from the local (above).
tf_dtypes: Captured from the local (above).
shapes: Captured from the local (above).
Returns:
An instance of `tf.data.Dataset`.
"""
with handle.graph.as_default():
it = tf.data.Iterator.from_string_handle(handle, tf_dtypes, shapes)
# In order to convert an iterator into something that looks like a data
# set, we create a dummy data set that consists of an infinite sequence
# of zeroes, and filter it through a map function that invokes
# 'it.get_next()' for each of those zeroes.
# TODO(b/113112108): Possibly replace this with something more canonical
# if and when we can find adequate support for abstractly defined data
# sets (at the moment of this writing it does not appear to exist yet).
return tf.data.Dataset.range(1).repeat().map(lambda _: it.get_next())

# NOTE: To revert to the old behavior, simply invoke `make()` here directly.
return OneShotDataset(make, type_spec)
tf.compat.v2.data.Dataset)


def make_dataset_from_variant_tensor(variant_tensor, type_spec):
Expand Down Expand Up @@ -572,16 +491,6 @@ def _get_bindings_for_elements(name_value_pairs, graph, type_fn):
pb.TensorFlow.Binding(
sequence=pb.TensorFlow.SequenceBinding(
variant_tensor_name=variant_tensor.name)))
elif isinstance(result, OneShotDataset):
# TODO(b/129956296): Eventually delete this deprecated code path.
element_type = type_utils.tf_dtypes_and_shapes_to_type(
tf.compat.v1.data.get_output_types(result),
tf.compat.v1.data.get_output_shapes(result))
handle_name = result.make_one_shot_iterator().string_handle().name
return (computation_types.SequenceType(element_type),
pb.TensorFlow.Binding(
sequence=pb.TensorFlow.SequenceBinding(
iterator_string_handle_name=handle_name)))
else:
raise TypeError('Cannot capture a result of an unsupported type {}.'.format(
py_typecheck.type_string(type(result))))
Expand Down Expand Up @@ -760,11 +669,7 @@ def assemble_result_from_graph(type_spec, binding, output_map):
'Expected a sequence binding, found {}.'.format(binding_oneof))
else:
sequence_oneof = binding.sequence.WhichOneof('binding')
if sequence_oneof == 'iterator_string_handle_name':
# TODO(b/129956296): Eventually delete this deprecated code path.
handle = output_map[binding.sequence.iterator_string_handle_name]
return make_dataset_from_string_handle(handle, type_spec.element)
elif sequence_oneof == 'variant_tensor_name':
if sequence_oneof == 'variant_tensor_name':
variant_tensor = output_map[binding.sequence.variant_tensor_name]
return make_dataset_from_variant_tensor(variant_tensor,
type_spec.element)
Expand Down
46 changes: 9 additions & 37 deletions tensorflow_federated/python/core/impl/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,22 +471,20 @@ def test_assemble_result_from_graph_with_sequence_of_odicts(self):
type_spec = computation_types.SequenceType(
collections.OrderedDict([('X', tf.int32), ('Y', tf.int32)]))
binding = pb.TensorFlow.Binding(
sequence=pb.TensorFlow.SequenceBinding(
iterator_string_handle_name='foo'))
sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='foo'))
data_set = tf.data.Dataset.from_tensors({
'X': tf.constant(1),
'Y': tf.constant(2)
})
it = data_set.make_one_shot_iterator()
output_map = {'foo': it.string_handle()}
output_map = {'foo': tf.data.experimental.to_variant(data_set)}
result = graph_utils.assemble_result_from_graph(type_spec, binding,
output_map)
self.assertIsInstance(result, graph_utils.DATASET_REPRESENTATION_TYPES)
self.assertEqual(
str(result.output_types),
str(tf.compat.v1.data.get_output_types(result)),
'OrderedDict([(\'X\', tf.int32), (\'Y\', tf.int32)])')
self.assertEqual(
str(result.output_shapes),
str(tf.compat.v1.data.get_output_shapes(result)),
'OrderedDict([(\'X\', TensorShape([])), (\'Y\', TensorShape([]))])')

@test.graph_mode_test
Expand All @@ -495,21 +493,20 @@ def test_assemble_result_from_graph_with_sequence_of_namedtuples(self):
type_spec = computation_types.SequenceType(
named_tuple_type(tf.int32, tf.int32))
binding = pb.TensorFlow.Binding(
sequence=pb.TensorFlow.SequenceBinding(
iterator_string_handle_name='foo'))
sequence=pb.TensorFlow.SequenceBinding(variant_tensor_name='foo'))
data_set = tf.data.Dataset.from_tensors({
'X': tf.constant(1),
'Y': tf.constant(2)
})
it = data_set.make_one_shot_iterator()
output_map = {'foo': it.string_handle()}
output_map = {'foo': tf.data.experimental.to_variant(data_set)}
result = graph_utils.assemble_result_from_graph(type_spec, binding,
output_map)
self.assertIsInstance(result, graph_utils.DATASET_REPRESENTATION_TYPES)
self.assertEqual(
str(result.output_types), 'TestNamedTuple(X=tf.int32, Y=tf.int32)')
str(tf.compat.v1.data.get_output_types(result)),
'TestNamedTuple(X=tf.int32, Y=tf.int32)')
self.assertEqual(
str(result.output_shapes),
str(tf.compat.v1.data.get_output_shapes(result)),
'TestNamedTuple(X=TensorShape([]), Y=TensorShape([]))')

def test_make_dummy_element_TensorType(self):
Expand Down Expand Up @@ -909,31 +906,6 @@ def test_make_data_set_from_elements_with_just_one_batch(self):
'x': np.array([1])
}], [('x', computation_types.TensorType(tf.int32, tf.TensorShape([None])))])

def test_one_shot_dataset_with_defuns(self):
with tf.Graph().as_default() as graph:
ds1 = tf.data.Dataset.from_tensor_slices([1, 1])
it1 = ds1.make_one_shot_iterator()
sh1 = it1.string_handle()

dtype = tf.int32
shape = tf.TensorShape([])

def make():
it2 = tf.data.Iterator.from_string_handle(sh1, dtype, shape)
return tf.data.Dataset.range(1).repeat().map(lambda _: it2.get_next())

ds2 = graph_utils.OneShotDataset(
make, computation_types.TensorType(dtype, shape))

@tf.function
def foo():
return ds2.reduce(np.int32(0), lambda x, y: x + y)

result = foo()

with tf.compat.v1.Session(graph=graph) as sess:
self.assertEqual(sess.run(result), 2)

def test_make_dataset_from_variant_tensor_constructs_dataset(self):
with tf.Graph().as_default():
ds = graph_utils.make_dataset_from_variant_tensor(
Expand Down

0 comments on commit 2760e50

Please sign in to comment.