Skip to content

Commit

Permalink
Add back support for bfloat16 dtypes.
Browse files Browse the repository at this point in the history
This adds a dependency on `ml_dtypes` package which provides the custom numpy
dtype for `bfloat16` which is the implementation of `tf.bfloat16`.

PiperOrigin-RevId: 662652837
  • Loading branch information
ZacharyGarrett authored and copybara-github committed Aug 13, 2024
1 parent 634bb40 commit 1acc172
Show file tree
Hide file tree
Showing 22 changed files with 249 additions and 20 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to
histograms, and a boolean indicating whether Laplace noise was used.
* Added some TFF executor classes to the public API (CPPExecutorFactory,
ResourceManagingExecutorFactory, RemoteExecutor, RemoteExecutorGrpcStub).
* Added support for `bfloat16` dtypes from the `ml_dtypes` package.

### Fixed

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
'grpcio~=1.46',
'jaxlib==0.4.14',
'jax==0.4.14',
'ml_dtypes>=0.2.0,==0.2.*',
'numpy~=1.25',
'portpicker~=1.6',
'scipy~=1.9.3',
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ google-vizier==0.1.11
grpcio~=1.46
jaxlib==0.4.14
jax==0.4.14
ml_dtypes>=0.2.0,==0.2.*
numpy~=1.25
portpicker~=1.6
scipy~=1.9.3
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/cc/core/impl/executors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,7 @@ cc_library(
"@org_tensorflow//tensorflow/compiler/xla:literal",
"@org_tensorflow//tensorflow/compiler/xla:shape_util",
"@org_tensorflow//tensorflow/compiler/xla:types",
"@org_tensorflow//tensorflow/compiler/xla:xla_data_proto_cc",
],
)

Expand Down
24 changes: 24 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/array_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,30 @@ inline absl::StatusOr<v0::Array> CreateArray(
return array_pb;
}

// Overload for Eigen::bfloat16.
inline absl::StatusOr<v0::Array> CreateArray(
v0::DataType dtype, v0::ArrayShape shape_pb,
std::initializer_list<const Eigen::bfloat16> values) {
v0::Array array_pb;
array_pb.set_dtype(dtype);
array_pb.mutable_shape()->Swap(&shape_pb);
switch (dtype) {
case v0::DataType::DT_BFLOAT16: {
auto size = values.size();
array_pb.mutable_bfloat16_list()->mutable_value()->Reserve(size);
for (auto element : values) {
array_pb.mutable_bfloat16_list()->mutable_value()->AddAlreadyReserved(
Eigen::numext::bit_cast<uint16_t>(element));
}
break;
}
default:
return absl::UnimplementedError(
absl::StrCat("Unexpected DataType found:", dtype));
}
return array_pb;
}

// Overload for complex.
template <typename T>
inline absl::StatusOr<v0::Array> CreateArray(
Expand Down
19 changes: 19 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/tensorflow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>
return Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(x));
});
}
// Overload for Eigen::bfloat16.
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>& src,
Eigen::bfloat16* dest) {
// Values of dtype ml_dtypes.bfloat16 are packed to and unpacked from a
// protobuf field of type int32 using the following logic in order to maintain
// compatibility with how other external environments (e.g. TensorFlow, Jax)
// represent values of ml_dtypes.bfloat16.
std::transform(src.begin(), src.end(), dest, [](int x) -> Eigen::bfloat16 {
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<uint16_t>(x));
});
}

// Overload for complex.
template <typename T>
Expand Down Expand Up @@ -186,6 +197,14 @@ absl::StatusOr<tensorflow::Tensor> TensorFromArray(const v0::Array& array_pb) {
tensor.flat<Eigen::half>().data());
return tensor;
}
case v0::Array::kBfloat16List: {
tensorflow::Tensor tensor(
tensorflow::DataTypeToEnum<Eigen::bfloat16>::value,
TFF_TRY(TensorShapeFromArrayShape(array_pb.shape())));
CopyFromRepeatedField(array_pb.bfloat16_list().value(),
tensor.flat<Eigen::bfloat16>().data());
return tensor;
}
case v0::Array::kFloat32List: {
tensorflow::Tensor tensor(
tensorflow::DataTypeToEnum<float>::value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ INSTANTIATE_TEST_SUITE_P(
.value(),
tensorflow::test::AsScalar(Eigen::half{1.0}),
},
{
"bfloat16",
testing::CreateArray(v0::DataType::DT_BFLOAT16,
testing::CreateArrayShape({}),
{Eigen::bfloat16{1.0}})
.value(),
tensorflow::test::AsScalar(Eigen::bfloat16{1.0}),
},
{
"float32",
testing::CreateArray(v0::DataType::DT_FLOAT,
Expand Down
22 changes: 22 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/xla_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow_federated/cc/core/impl/executors/status_macros.h"
#include "tensorflow_federated/proto/v0/array.pb.h"
#include "tensorflow_federated/proto/v0/computation.pb.h"
Expand Down Expand Up @@ -64,6 +65,8 @@ absl::StatusOr<xla::PrimitiveType> PrimitiveTypeFromDataType(
return xla::C64;
case v0::DataType::DT_COMPLEX128:
return xla::C128;
case v0::DataType::DT_BFLOAT16:
return xla::BF16;
default:
return absl::UnimplementedError(
absl::StrCat("Unexpected DataType found:", data_type));
Expand Down Expand Up @@ -118,6 +121,18 @@ static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>
});
}

// Overload for Eigen::bflot16.
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<int32_t>& src,
Eigen::bfloat16* dest) {
// Values of dtype ml_dtypes.bfloat16 are packed to and unpacked from a
// protobuf field of type int32 using the following logic in order to maintain
// compatibility with how other external environments (e.g. TensorFlow, Jax)
// represent values of ml_dtypes.bfloat16.
std::transform(src.begin(), src.end(), dest, [](int x) -> Eigen::bfloat16 {
return Eigen::numext::bit_cast<Eigen::bfloat16>(static_cast<uint16_t>(x));
});
}

// Overload for complex.
template <typename T>
static void CopyFromRepeatedField(const google::protobuf::RepeatedField<T>& src,
Expand Down Expand Up @@ -197,6 +212,13 @@ absl::StatusOr<xla::Literal> LiteralFromArray(const v0::Array& array_pb) {
literal.data<xla::half>().begin());
return literal;
}
case v0::Array::kBfloat16List: {
xla::Literal literal(TFF_TRY(
ShapeFromArrayShape(v0::DataType::DT_BFLOAT16, array_pb.shape())));
CopyFromRepeatedField(array_pb.bfloat16_list().value(),
literal.data<xla::bfloat16>().begin());
return literal;
}
case v0::Array::kFloat32List: {
xla::Literal literal(TFF_TRY(
ShapeFromArrayShape(v0::DataType::DT_FLOAT, array_pb.shape())));
Expand Down
13 changes: 13 additions & 0 deletions tensorflow_federated/cc/core/impl/executors/xla_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,19 @@ TEST(LiteralFromArrayTest, TestReturnsLiteral_float16) {
EXPECT_EQ(actual_literal, expected_literal);
}

TEST(LiteralFromArrayTest, TestReturnsLiteral_bfloat16) {
const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray(
v0::DataType::DT_BFLOAT16, testing::CreateArrayShape({}),
{Eigen::bfloat16{1.0}}));

const xla::Literal& actual_literal =
TFF_ASSERT_OK(LiteralFromArray(array_pb));

xla::Literal expected_literal =
xla::LiteralUtil::CreateR0(Eigen::bfloat16{1.0});
EXPECT_EQ(actual_literal, expected_literal);
}

TEST(LiteralFromArrayTest, TestReturnsLiteral_float32) {
const v0::Array& array_pb = TFF_ASSERT_OK(testing::CreateArray(
v0::DataType::DT_FLOAT, testing::CreateArrayShape({}), {1.0}));
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_federated/proto/v0/array.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ message Array {
message BoolList {
repeated bool value = 1;
}
// INT8, INT16, INT32, UINT8, UINT16, HALF
// INT8, INT16, INT32, UINT8, UINT16, HALF, BFLOAT16
message IntList {
repeated int32 value = 1;
}
Expand Down Expand Up @@ -80,6 +80,7 @@ message Array {
DoubleList float64_list = 14;
FloatList complex64_list = 15;
DoubleList complex128_list = 16;
IntList bfloat16_list = 19;
BytesList string_list = 17;
}
}
25 changes: 13 additions & 12 deletions tensorflow_federated/proto/v0/data_type.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@ syntax = "proto3";

package tensorflow_federated.v0;

// Only simple data types are currently supported.
enum DataType {
// Sorted by first kind (bool, int, uint, float, complex, string), then by
// wether the dtype exists natively in numpy, and finally by bit width.
DT_INVALID = 0;
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_BOOL = 10;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8;
DT_INT16 = 5;
DT_INT32 = 3;
DT_INT64 = 9;
DT_BOOL = 10;
DT_UINT8 = 4;
DT_UINT16 = 17;
DT_COMPLEX128 = 18;
DT_HALF = 19;
DT_UINT32 = 22;
DT_UINT64 = 23;
DT_HALF = 19;
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_COMPLEX64 = 8;
DT_COMPLEX128 = 18;
DT_BFLOAT16 = 14;
DT_STRING = 7;

reserved 11; // DT_QINT8
reserved 12; // DT_QUINT8
reserved 13; // DT_QINT32
reserved 14; // DT_BFLOAT16
reserved 15; // DT_QINT16
reserved 16; // DT_QUINT16
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
import ml_dtypes
import numpy as np

from tensorflow_federated.python.core.environments.jax_frontend import jax_computation
Expand Down Expand Up @@ -90,6 +91,7 @@ def _comp(y, x):
('float16', computation_types.TensorType(np.float16)),
('float32', computation_types.TensorType(np.float32)),
('complex64', computation_types.TensorType(np.complex64)),
('bfloat16', computation_types.TensorType(ml_dtypes.bfloat16)),
('generic', computation_types.TensorType(np.int32)),
('array', computation_types.TensorType(np.int32, shape=[3])),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ py_test(
"//tensorflow_federated/python/core/impl/context_stack:runtime_error_context",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_test_utils",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tensorflow_federated.python.core.impl.context_stack import runtime_error_context
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_test_utils


def one_arg_fn(x):
Expand Down Expand Up @@ -424,6 +425,20 @@ def _():
stack.current, runtime_error_context.RuntimeErrorContext
)

def test_custom_numpy_dtype(self):

@tensorflow_computation.tf_computation(tf.bfloat16)
def foo(x):
return x

type_test_utils.assert_types_identical(
foo.type_signature,
computation_types.FunctionType(
parameter=computation_types.TensorType(tf.bfloat16.as_numpy_dtype),
result=computation_types.TensorType(tf.bfloat16.as_numpy_dtype),
),
)


if __name__ == '__main__':
absltest.main()
26 changes: 23 additions & 3 deletions tensorflow_federated/python/core/impl/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import Optional, Union

import ml_dtypes
import numpy as np

from tensorflow_federated.proto.v0 import array_pb2
Expand Down Expand Up @@ -67,6 +68,13 @@ def from_proto(array_pb: array_pb2.Array) -> Array:
# compatibility with how other external environments (e.g., TensorFlow, JAX)
# represent values of `np.float16`.
value = np.asarray(value, np.uint16).view(np.float16).tolist()
elif dtype is ml_dtypes.bfloat16:
value = array_pb.bfloat16_list.value
# Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a
# protobuf field of type `int32` using the following logic in order to
# maintain compatibility with how other external environments (e.g.,
# TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`.
value = np.asarray(value, np.uint16).view(ml_dtypes.bfloat16).tolist()
elif dtype is np.float32:
value = array_pb.float32_list.value
elif dtype is np.float64:
Expand Down Expand Up @@ -261,6 +269,17 @@ def _contains_type(value, classinfo):
shape=shape_pb,
complex128_list=array_pb2.Array.DoubleList(value=packed_value),
)
elif dtype is ml_dtypes.bfloat16:
# Values of dtype `ml_dtypes.bfloat16` are packed to and unpacked from a
# protobuf field of type `int32` using the following logic in order to
# maintain compatibility with how other external environments (e.g.,
# TensorFlow, JAX) represent values of `ml_dtypes.bfloat16`.
value = np.asarray(value, ml_dtypes.bfloat16).view(np.uint16).tolist()
return array_pb2.Array(
dtype=dtype_pb,
shape=shape_pb,
bfloat16_list=array_pb2.Array.IntList(value=value),
)
elif dtype is np.str_:
return array_pb2.Array(
dtype=dtype_pb,
Expand All @@ -284,13 +303,14 @@ def is_compatible_dtype(value: Array, dtype: type[np.generic]) -> bool:
value: The value to check.
dtype: The scalar `np.generic` to check against.
"""

# Check dtype kind.
if isinstance(value, (np.ndarray, np.generic)):
value_dtype = value.dtype
value_dtype = value.dtype.type
else:
value_dtype = type(value)
if value_dtype is dtype:
return True

# Check dtype kind.
if np.issubdtype(value_dtype, np.bool_):
# Skip checking dtype size, `np.bool_` does not have a size.
return dtype is np.bool_
Expand Down
Loading

0 comments on commit 1acc172

Please sign in to comment.