Skip to content

Commit

Permalink
Simplify by using dtype_utils.is_valid_dtype in computation_types.py …
Browse files Browse the repository at this point in the history
…instead of a local list of valid dtypes.

This also results in code like `tff.types.TensorStype(tf.string)` now raising an error (the intended behavior), and must be changed to `tff.types.TensorStype(np.str_)`.

PiperOrigin-RevId: 662644692
  • Loading branch information
ZacharyGarrett authored and copybara-github committed Aug 13, 2024
1 parent 9d381b8 commit 634bb40
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 29 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ and this project adheres to
* Added some TFF executor classes to the public API (CPPExecutorFactory,
ResourceManagingExecutorFactory, RemoteExecutor, RemoteExecutorGrpcStub).

### Fixed

* A bug where `tf.string` was mistakenly allowed as a dtype to
`tff.types.TensorType`. This now must be `np.str_`.

### Changed

* `tff.Computation` and `tff.framework.ConcreteComputation` to be able to
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/impl/types/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ py_library(
srcs = ["computation_types.py"],
deps = [
":array_shape",
":dtype_utils",
":placements",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
Expand Down
32 changes: 4 additions & 28 deletions tensorflow_federated/python/core/impl/types/computation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.types import array_shape
from tensorflow_federated.python.core.impl.types import dtype_utils
from tensorflow_federated.python.core.impl.types import placements

T = TypeVar('T')
Expand Down Expand Up @@ -307,30 +308,6 @@ def _is_dtype_like(obj: object) -> TypeGuard[_DtypeLike]:
return isinstance(obj, np.dtype)


_ALLOWED_NP_DTYPES = [
np.bool_,
np.bytes_,
np.complex128,
np.complex64,
np.float16,
np.float32,
np.float64,
np.int16,
np.int32,
np.int64,
np.int8,
np.str_,
np.uint16,
np.uint32,
np.uint64,
np.uint8,
]


def _is_allowed_np_dtype(dtype: np.dtype) -> bool:
return dtype in _ALLOWED_NP_DTYPES


def _is_array_shape_like(
obj: object,
) -> TypeGuard[Union[array_shape._ArrayShapeLike]]:
Expand Down Expand Up @@ -359,13 +336,12 @@ def _to_dtype(dtype: _DtypeLike) -> np.dtype:
"""
if isinstance(dtype, np.dtype):
dtype = dtype.type
if dtype == np.bytes_:
if dtype is np.bytes_:
dtype = np.str_
dtype = np.dtype(dtype)

if not _is_allowed_np_dtype(dtype):
if not dtype_utils.is_valid_dtype(dtype):
raise NotImplementedError(f'Unexpected `dtype` found: {dtype}.')
return dtype
return np.dtype(dtype)


class TensorType(Type, metaclass=_Intern):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def result(self):
@federated_computation.federated_computation(
computation_types.FederatedType(
collections.OrderedDict(
custom_sum=computation_types.TensorType(tf.string),
custom_sum=computation_types.TensorType(np.str_),
),
placements.CLIENTS,
)
Expand Down

0 comments on commit 634bb40

Please sign in to comment.