diff --git a/RELEASE.md b/RELEASE.md index acec8b2a81..244724e398 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -18,6 +18,7 @@ and this project adheres to histograms, and a boolean indicating whether Laplace noise was used. * Added some TFF executor classes to the public API (CPPExecutorFactory, ResourceManagingExecutorFactory, RemoteExecutor, RemoteExecutorGrpcStub). +* Added support for `bfloat16` dtypes from the `ml_dtypes` package. ### Fixed diff --git a/pyproject.toml b/pyproject.toml index bcd14f72b9..e4add3ef6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ 'grpcio~=1.46', 'jaxlib==0.4.14', 'jax==0.4.14', + 'ml_dtypes>=0.2.0,==0.2.*', 'numpy~=1.25', 'portpicker~=1.6', 'scipy~=1.9.3', diff --git a/requirements.txt b/requirements.txt index cc247f214d..50788d552e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,7 @@ google-vizier==0.1.11 grpcio~=1.46 jaxlib==0.4.14 jax==0.4.14 +ml_dtypes>=0.2.0,==0.2.* numpy~=1.25 portpicker~=1.6 scipy~=1.9.3 diff --git a/tensorflow_federated/cc/core/impl/executors/BUILD b/tensorflow_federated/cc/core/impl/executors/BUILD index 83f3a46852..81377857fc 100644 --- a/tensorflow_federated/cc/core/impl/executors/BUILD +++ b/tensorflow_federated/cc/core/impl/executors/BUILD @@ -1228,6 +1228,7 @@ cc_library( "@org_tensorflow//tensorflow/compiler/xla:literal", "@org_tensorflow//tensorflow/compiler/xla:shape_util", "@org_tensorflow//tensorflow/compiler/xla:types", + "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc", ], ) diff --git a/tensorflow_federated/cc/core/impl/executors/array_test_utils.h b/tensorflow_federated/cc/core/impl/executors/array_test_utils.h index 52bdf987d5..d3080c2247 100644 --- a/tensorflow_federated/cc/core/impl/executors/array_test_utils.h +++ b/tensorflow_federated/cc/core/impl/executors/array_test_utils.h @@ -124,6 +124,30 @@ inline absl::StatusOr CreateArray( return array_pb; } +// Overload for Eigen::bfloat16. 
+inline absl::StatusOr CreateArray( + v0::DataType dtype, v0::ArrayShape shape_pb, + std::initializer_list values) { + v0::Array array_pb; + array_pb.set_dtype(dtype); + array_pb.mutable_shape()->Swap(&shape_pb); + switch (dtype) { + case v0::DataType::DT_BFLOAT16: { + auto size = values.size(); + array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size); + for (auto element : values) { + array_pb.mutable_bfloat16_list()->mutable_value()->AddAlreadyReserved( + Eigen::numext::bit_cast(element)); + } + break; + } + default: + return absl::UnimplementedError( + absl::StrCat("Unexpected DataType found:", dtype)); + } + return array_pb; +} + // Overload for complex. template inline absl::StatusOr CreateArray( diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc index d2f10a7fea..183669ee50 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc @@ -89,6 +89,17 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField return Eigen::numext::bit_cast(static_cast(x)); }); } +// Overload for Eigen::bfloat16. +static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, + Eigen::bfloat16* dest) { + // Values of dtype ml_dtypes.bfloat16 are packed to and unpacked from a + // protobuf field of type int32 using the following logic in order to maintain + // compatibility with how other external environments (e.g. TensorFlow, Jax) + // represent values of ml_dtypes.bfloat16. + std::transform(src.begin(), src.end(), dest, [](int x) -> Eigen::bfloat16 { + return Eigen::numext::bit_cast(static_cast(x)); + }); +} // Overload for complex. 
template @@ -186,6 +197,14 @@ absl::StatusOr TensorFromArray(const v0::Array& array_pb) { tensor.flat().data()); return tensor; } + case v0::Array::kBfloat16List: { + tensorflow::Tensor tensor( + tensorflow::DataTypeToEnum::value, + TFF_TRY(TensorShapeFromArrayShape(array_pb.shape()))); + CopyFromRepeatedField(array_pb.bfloat16_list().value(), + tensor.flat().data()); + return tensor; + } case v0::Array::kFloat32List: { tensorflow::Tensor tensor( tensorflow::DataTypeToEnum::value, diff --git a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc index 73d44bdcff..a0a1f62001 100644 --- a/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/tensorflow_utils_test.cc @@ -214,6 +214,14 @@ INSTANTIATE_TEST_SUITE_P( .value(), tensorflow::test::AsScalar(Eigen::half{1.0}), }, + { + "bfloat16", + testing::CreateArray(v0::DataType::DT_BFLOAT16, + testing::CreateArrayShape({}), + {Eigen::bfloat16{1.0}}) + .value(), + tensorflow::test::AsScalar(Eigen::bfloat16{1.0}), + }, { "float32", testing::CreateArray(v0::DataType::DT_FLOAT, diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils.cc b/tensorflow_federated/cc/core/impl/executors/xla_utils.cc index ab94b79db9..b7e462defa 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils.cc @@ -26,6 +26,7 @@ limitations under the License #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow_federated/cc/core/impl/executors/status_macros.h" #include "tensorflow_federated/proto/v0/array.pb.h" #include "tensorflow_federated/proto/v0/computation.pb.h" @@ -64,6 +65,8 @@ absl::StatusOr PrimitiveTypeFromDataType( return xla::C64; case v0::DataType::DT_COMPLEX128: 
return xla::C128; + case v0::DataType::DT_BFLOAT16: + return xla::BF16; default: return absl::UnimplementedError( absl::StrCat("Unexpected DataType found:", data_type)); @@ -118,6 +121,18 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField }); } +// Overload for Eigen::bfloat16. +static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, + Eigen::bfloat16* dest) { + // Values of dtype ml_dtypes.bfloat16 are packed to and unpacked from a + // protobuf field of type int32 using the following logic in order to maintain + // compatibility with how other external environments (e.g. TensorFlow, Jax) + // represent values of ml_dtypes.bfloat16. + std::transform(src.begin(), src.end(), dest, [](int x) -> Eigen::bfloat16 { + return Eigen::numext::bit_cast(static_cast(x)); + }); +} + // Overload for complex. template static void CopyFromRepeatedField(const google::protobuf::RepeatedField& src, @@ -197,6 +212,13 @@ absl::StatusOr LiteralFromArray(const v0::Array& array_pb) { literal.data().begin()); return literal; } + case v0::Array::kBfloat16List: { + xla::Literal literal(TFF_TRY( + ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape()))); + CopyFromRepeatedField(array_pb.bfloat16_list().value(), + literal.data().begin()); + return literal; + } case v0::Array::kFloat32List: { xla::Literal literal(TFF_TRY( ShapeFromArrayShape(v0::DataType::DT_FLOAT, array_pb.shape()))); diff --git a/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc b/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc index bab406f0e8..a36a39d88a 100644 --- a/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc +++ b/tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc @@ -239,6 +239,19 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float16) { EXPECT_EQ(actual_literal, expected_literal); } +TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) { + const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( + 
v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}), + {Eigen::bfloat16{1.0}})); + + const xla::Literal& actual_literal = + TFF_ASSERT_OK(LiteralFromArray(array_pb)); + + xla::Literal expected_literal = + xla::LiteralUtil::CreateR0(Eigen::bfloat16{1.0}); + EXPECT_EQ(actual_literal, expected_literal); +} + TEST(LiteralFromArrayTest, TestReturnsLiteral_float32) { const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray( v0::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0})); diff --git a/tensorflow_federated/proto/v0/array.proto b/tensorflow_federated/proto/v0/array.proto index 1e47bf5741..aad5d34f7b 100644 --- a/tensorflow_federated/proto/v0/array.proto +++ b/tensorflow_federated/proto/v0/array.proto @@ -42,7 +42,7 @@ message Array { message BoolList { repeated bool value = 1; } - // INT8, INT16, INT32, UINT8, UINT16, HALF + // INT8, INT16, INT32, UINT8, UINT16, HALF, BFLOAT16 message IntList { repeated int32 value = 1; } @@ -80,6 +80,7 @@ message Array { DoubleList float64_list = 14; FloatList complex64_list = 15; DoubleList complex128_list = 16; + IntList bfloat16_list = 19; BytesList string_list = 17; } } diff --git a/tensorflow_federated/proto/v0/data_type.proto b/tensorflow_federated/proto/v0/data_type.proto index 65def9bd44..14b902a744 100644 --- a/tensorflow_federated/proto/v0/data_type.proto +++ b/tensorflow_federated/proto/v0/data_type.proto @@ -2,29 +2,30 @@ syntax = "proto3"; package tensorflow_federated.v0; -// Only simple data types are currently supported. enum DataType { + // Sorted by first kind (bool, int, uint, float, complex, string), then by + // whether the dtype exists natively in numpy, and finally by bit width. 
DT_INVALID = 0; - DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; + DT_BOOL = 10; DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; + DT_INT16 = 5; + DT_INT32 = 3; DT_INT64 = 9; - DT_BOOL = 10; + DT_UINT8 = 4; DT_UINT16 = 17; - DT_COMPLEX128 = 18; - DT_HALF = 19; DT_UINT32 = 22; DT_UINT64 = 23; + DT_HALF = 19; + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_COMPLEX64 = 8; + DT_COMPLEX128 = 18; + DT_BFLOAT16 = 14; + DT_STRING = 7; reserved 11; // DT_QINT8 reserved 12; // DT_QUINT8 reserved 13; // DT_QINT32 - reserved 14; // DT_BFLOAT16 reserved 15; // DT_QINT16 reserved 16; // DT_QUINT16 } diff --git a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py index 855dc2e175..69db93340f 100644 --- a/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py +++ b/tensorflow_federated/python/core/environments/jax_frontend/jax_computation_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax +import ml_dtypes import numpy as np from tensorflow_federated.python.core.environments.jax_frontend import jax_computation @@ -90,6 +91,7 @@ def _comp(y, x): ('float16', computation_types.TensorType(np.float16)), ('float32', computation_types.TensorType(np.float32)), ('complex64', computation_types.TensorType(np.complex64)), + ('bfloat16', computation_types.TensorType(ml_dtypes.bfloat16)), ('generic', computation_types.TensorType(np.int32)), ('array', computation_types.TensorType(np.int32, shape=[3])), ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD b/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD index a30279b6de..c58e0a3e8f 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/BUILD @@ -49,6 +49,7 @@ 
py_test( "//tensorflow_federated/python/core/impl/context_stack:runtime_error_context", "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/impl/types:placements", + "//tensorflow_federated/python/core/impl/types:type_test_utils", ], ) diff --git a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py index 6026870ac2..1bede11a20 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py +++ b/tensorflow_federated/python/core/environments/tensorflow_frontend/tensorflow_computation_test.py @@ -26,6 +26,7 @@ from tensorflow_federated.python.core.impl.context_stack import runtime_error_context from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import placements +from tensorflow_federated.python.core.impl.types import type_test_utils def one_arg_fn(x): @@ -424,6 +425,20 @@ def _(): stack.current, runtime_error_context.RuntimeErrorContext ) + def test_custom_numpy_dtype(self): + + @tensorflow_computation.tf_computation(tf.bfloat16) + def foo(x): + return x + + type_test_utils.assert_types_identical( + foo.type_signature, + computation_types.FunctionType( + parameter=computation_types.TensorType(tf.bfloat16.as_numpy_dtype), + result=computation_types.TensorType(tf.bfloat16.as_numpy_dtype), + ), + ) + if __name__ == '__main__': absltest.main() diff --git a/tensorflow_federated/python/core/impl/compiler/array.py b/tensorflow_federated/python/core/impl/compiler/array.py index 5ef7650d28..6abec43661 100644 --- a/tensorflow_federated/python/core/impl/compiler/array.py +++ b/tensorflow_federated/python/core/impl/compiler/array.py @@ -15,6 +15,7 @@ from typing import Optional, Union +import ml_dtypes import numpy as np from tensorflow_federated.proto.v0 import 
array_pb2 @@ -67,6 +68,13 @@ def from_proto(array_pb: array_pb2.Array) -> Array: # compatibility with how other external environments (e.g., TensorFlow, JAX) # represent values of `np.float16`. value = np.asarray(value, np.uint16).view(np.float16).tolist() + elif dtype is ml_dtypes.bfloat16: + value = array_pb.bfloat16_list.value + # Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a + # protobuf field of type `int32` using the following logic in order to + # maintain compatibility with how other external environments (e.g., + # TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`. + value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist() elif dtype is np.float32: value = array_pb.float32_list.value elif dtype is np.float64: @@ -261,6 +269,17 @@ def _contains_type(value, classinfo): shape=shape_pb, complex128_list=array_pb2.Array.DoubleList(value=packed_value), ) + elif dtype is ml_dtypes.bfloat16: + # Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a + # protobuf field of type `int32` using the following logic in order to + # maintain compatibility with how other external environments (e.g., + # TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`. + value = np.asarray(value, ml_dtypes.bfloat16).view(np.uint16).tolist() + return array_pb2.Array( + dtype=dtype_pb, + shape=shape_pb, + bfloat16_list=array_pb2.Array.IntList(value=value), + ) elif dtype is np.str_: return array_pb2.Array( dtype=dtype_pb, @@ -284,13 +303,14 @@ def is_compatible_dtype(value: Array, dtype: type[np.generic]) -> bool: value: The value to check. dtype: The scalar `np.generic` to check against. """ - - # Check dtype kind. if isinstance(value, (np.ndarray, np.generic)): - value_dtype = value.dtype + value_dtype = value.dtype.type else: value_dtype = type(value) + if value_dtype is dtype: + return True + # Check dtype kind. if np.issubdtype(value_dtype, np.bool_): # Skip checking dtype size, `np.bool_` does not have a size. 
return dtype is np.bool_ diff --git a/tensorflow_federated/python/core/impl/compiler/array_test.py b/tensorflow_federated/python/core/impl/compiler/array_test.py index 1c6bf5313a..f11f9b4254 100644 --- a/tensorflow_federated/python/core/impl/compiler/array_test.py +++ b/tensorflow_federated/python/core/impl/compiler/array_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized +import ml_dtypes import numpy as np from tensorflow_federated.proto.v0 import array_pb2 @@ -118,6 +119,19 @@ class FromProtoTest(parameterized.TestCase): ), np.float16(1.0), ), + ( + 'bfloat16', + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_BFLOAT16, + shape=array_pb2.ArrayShape(dim=[]), + bfloat16_list=array_pb2.Array.IntList( + value=[ + np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() + ] + ), + ), + ml_dtypes.bfloat16(1.0), + ), ( 'float32', array_pb2.Array( @@ -349,6 +363,30 @@ class ToProtoTest(parameterized.TestCase): int32_list=array_pb2.Array.IntList(value=[1]), ), ), + ( + 'generic_float16', + np.float16(1.0), + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_HALF, + shape=array_pb2.ArrayShape(dim=[]), + float16_list=array_pb2.Array.IntList( + value=[np.asarray(1.0, np.float16).view(np.uint16).item()] + ), + ), + ), + ( + 'generic_bfloat16', + ml_dtypes.bfloat16(1.0), + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_BFLOAT16, + shape=array_pb2.ArrayShape(dim=[]), + bfloat16_list=array_pb2.Array.IntList( + value=[ + np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() + ] + ), + ), + ), ( 'generic_str', np.str_('abc'), @@ -377,7 +415,7 @@ class ToProtoTest(parameterized.TestCase): ), ), ( - 'array_int32_epmty', + 'array_int32_empty', np.array([], np.int32), array_pb2.Array( dtype=data_type_pb2.DataType.DT_INT32, @@ -385,6 +423,30 @@ class ToProtoTest(parameterized.TestCase): int32_list=array_pb2.Array.IntList(value=[]), ), ), + ( + 'array_float16', + np.array([1.0], np.float16), + array_pb2.Array( + 
dtype=data_type_pb2.DataType.DT_HALF, + shape=array_pb2.ArrayShape(dim=[1]), + float16_list=array_pb2.Array.IntList( + value=[np.asarray(1.0, np.float16).view(np.uint16).item()] + ), + ), + ), + ( + 'array_bfloat16', + np.array([1.0], ml_dtypes.bfloat16), + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_BFLOAT16, + shape=array_pb2.ArrayShape(dim=[1]), + bfloat16_list=array_pb2.Array.IntList( + value=[ + np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() + ] + ), + ), + ), ( 'array_str', np.array(['abc', 'def'], np.str_), @@ -606,6 +668,23 @@ def test_returns_value_with_no_dtype_hint(self, value, expected_value): complex128_list=array_pb2.Array.DoubleList(value=[1.0, 1.0]), ), ), + ( + 'bfloat16', + # Note: we must not use Python `float` here because ml_dtypes.bfloat16 + # is declared as kind `V` (void) not `f` (float) to prevent numpy from + # trying to equate float16 and bfloat16 (which are not compatible). + ml_dtypes.bfloat16(1.0), + ml_dtypes.bfloat16, + array_pb2.Array( + dtype=data_type_pb2.DataType.DT_BFLOAT16, + shape=array_pb2.ArrayShape(dim=[]), + bfloat16_list=array_pb2.Array.IntList( + value=[ + np.asarray(1.0, ml_dtypes.bfloat16).view(np.uint16).item() + ] + ), + ), + ), ( 'str', 'abc', @@ -843,6 +922,7 @@ class IsCompatibleDtypeTest(parameterized.TestCase): ('float64', 1.0, np.float64), ('complex64', (1.0 + 1.0j), np.complex64), ('complex128', (1.0 + 1.0j), np.complex128), + ('bfloat16', ml_dtypes.bfloat16(1.0), ml_dtypes.bfloat16), ('str', 'abc', np.str_), ('bytes', b'abc', np.str_), ('generic_int32', np.int32(1), np.int32), @@ -863,6 +943,7 @@ def test_returns_true(self, value, dtype): @parameterized.named_parameters( ('scalar_incompatible_dtype_kind', 1, np.float32), + ('scalar_incompatible_dtype_kind_bfloat16', 1.0, ml_dtypes.bfloat16), ('scalar_incompatible_dtype_size_int', np.iinfo(np.int64).max, np.int32), ( 'scalar_incompatible_dtype_size_float', diff --git a/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py 
b/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py index 9e7663d6dc..e1a50e040c 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py +++ b/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py @@ -14,6 +14,7 @@ from absl.testing import absltest from absl.testing import parameterized +import ml_dtypes import numpy as np import tree @@ -1725,6 +1726,11 @@ class LiteralTest(parameterized.TestCase): ('float64', 1.0, computation_types.TensorType(np.float64)), ('complex64', (1.0 + 1.0j), computation_types.TensorType(np.complex64)), ('complex128', (1.0 + 1.0j), computation_types.TensorType(np.complex128)), + ( + 'bfloat16', + ml_dtypes.bfloat16(1.0), + computation_types.TensorType(ml_dtypes.bfloat16), + ), ('str', 'a', computation_types.TensorType(np.str_)), ('bytes', b'a', computation_types.TensorType(np.str_)), ('generic_int', np.int32(1), computation_types.TensorType(np.int32)), @@ -1785,8 +1791,8 @@ class LiteralTest(parameterized.TestCase): def test_init_does_not_raise_value_error(self, value, type_signature): try: building_blocks.Literal(value, type_signature) - except ValueError: - self.fail('Raised `ValueError` unexpectedly.') + except ValueError as e: + self.fail(f'Raised `ValueError` unexpectedly: {e}') @parameterized.named_parameters( ('str', 'a', computation_types.TensorType(np.str_), b'a'), diff --git a/tensorflow_federated/python/core/impl/types/computation_types.py b/tensorflow_federated/python/core/impl/types/computation_types.py index 82d8bea51a..0355558c69 100644 --- a/tensorflow_federated/python/core/impl/types/computation_types.py +++ b/tensorflow_federated/python/core/impl/types/computation_types.py @@ -1315,6 +1315,7 @@ def _lines_for_type(type_spec, formatted): if type_spec.shape is None: return ['{!r}(shape=None)'.format(type_spec.dtype.name)] elif type_spec.shape: + def _value_string(value): return str(value) if value is not None else '?' 
diff --git a/tensorflow_federated/python/core/impl/types/dtype_utils.py b/tensorflow_federated/python/core/impl/types/dtype_utils.py index ec213d3a2b..2ce4ef0313 100644 --- a/tensorflow_federated/python/core/impl/types/dtype_utils.py +++ b/tensorflow_federated/python/core/impl/types/dtype_utils.py @@ -15,6 +15,7 @@ from collections.abc import Mapping from typing import Union +import ml_dtypes import numpy as np from tensorflow_federated.proto.v0 import data_type_pb2 @@ -36,6 +37,7 @@ data_type_pb2.DataType.DT_DOUBLE: np.float64, data_type_pb2.DataType.DT_COMPLEX64: np.complex64, data_type_pb2.DataType.DT_COMPLEX128: np.complex128, + data_type_pb2.DataType.DT_BFLOAT16: ml_dtypes.bfloat16, data_type_pb2.DataType.DT_STRING: np.str_, } @@ -66,6 +68,7 @@ def from_proto( np.float64: data_type_pb2.DataType.DT_DOUBLE, np.complex64: data_type_pb2.DataType.DT_COMPLEX64, np.complex128: data_type_pb2.DataType.DT_COMPLEX128, + ml_dtypes.bfloat16: data_type_pb2.DataType.DT_BFLOAT16, np.str_: data_type_pb2.DataType.DT_STRING, } diff --git a/tensorflow_federated/python/core/impl/types/dtype_utils_test.py b/tensorflow_federated/python/core/impl/types/dtype_utils_test.py index 9efeda4f60..cb03392c58 100644 --- a/tensorflow_federated/python/core/impl/types/dtype_utils_test.py +++ b/tensorflow_federated/python/core/impl/types/dtype_utils_test.py @@ -14,6 +14,7 @@ from absl.testing import absltest from absl.testing import parameterized +import ml_dtypes import numpy as np from tensorflow_federated.python.core.impl.types import dtype_utils @@ -47,6 +48,7 @@ def test_to_proto_raises_not_implemented_error(self, dtype): ('uint16', np.uint16), ('uint32', np.uint32), ('uint64', np.uint64), + ('bfloat16', ml_dtypes.bfloat16), ('float16', np.float16), ('float32', np.float32), ('float64', np.float64), diff --git a/tensorflow_federated/python/core/impl/types/type_analysis.py b/tensorflow_federated/python/core/impl/types/type_analysis.py index a58d4d9ad1..fd877a0161 100644 --- 
a/tensorflow_federated/python/core/impl/types/type_analysis.py +++ b/tensorflow_federated/python/core/impl/types/type_analysis.py @@ -17,6 +17,7 @@ from collections.abc import Callable from typing import Optional +import ml_dtypes import numpy as np from tensorflow_federated.python.common_libs import py_typecheck @@ -355,7 +356,10 @@ def check_is_sum_compatible(type_spec, type_spec_context=None): type_spec_context = type_spec py_typecheck.check_type(type_spec_context, computation_types.Type) if isinstance(type_spec, computation_types.TensorType): - if not np.issubdtype(type_spec.dtype, np.number): + if not ( + np.issubdtype(type_spec.dtype, np.number) + or type_spec.dtype == ml_dtypes.bfloat16 + ): raise SumIncompatibleError( type_spec, type_spec_context, f'{type_spec.dtype} is not numeric' ) diff --git a/tensorflow_federated/python/core/impl/types/type_analysis_test.py b/tensorflow_federated/python/core/impl/types/type_analysis_test.py index 592d619d94..a3b71b7207 100644 --- a/tensorflow_federated/python/core/impl/types/type_analysis_test.py +++ b/tensorflow_federated/python/core/impl/types/type_analysis_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest from absl.testing import parameterized +import ml_dtypes import numpy as np from tensorflow_federated.python.common_libs import structure @@ -237,6 +238,7 @@ class CheckIsSumCompatibleTest(parameterized.TestCase): @parameterized.named_parameters([ ('tensor_type', computation_types.TensorType(np.int32)), + ('bfloat16_type', computation_types.TensorType(ml_dtypes.bfloat16)), ( 'struct_type_int', computation_types.StructType(