enable setting drop_remainder #251

Open · wants to merge 1 commit into base: main
15 changes: 11 additions & 4 deletions airio/_src/pygrain/dataset_providers.py
@@ -87,6 +87,7 @@ def get_lazy_dataset(
shard_info: core_dataset_providers.ShardInfo | None,
num_epochs: int | None,
num_prefetch_threads: int | None,
+ drop_remainder: bool = False,
) -> lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset:
"""Returns a lazy dataset for Task source and preprocessors."""
# Step 1: Get Source.
@@ -156,7 +157,7 @@ def get_lazy_dataset(
runtime_preps.extend(runtime_preprocessors)
if batch_size:
runtime_preps.append(
- grain.Batch(batch_size=batch_size, drop_remainder=False)
+ grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)
)
unused_next_epoch_rng, prep_rng = jax.random.split(next_epoch_rng)
ds, _, _ = _apply_preprocessors_to_lazy_dataset(
@@ -185,6 +186,7 @@ def get_dataset(
num_epochs: int | None = 1,
num_prefetch_threads: int | None = None,
num_workers: int | None = 0,
+ drop_remainder: bool = False,
) -> clu_dataset_iterator.DatasetIterator:
"""Returns the dataset iterator as per the task configuration."""
# TODO(b/311720936): Until Task preprocessing is fully switched to
@@ -201,6 +203,7 @@
shard_info=shard_info,
num_epochs=num_epochs,
num_prefetch_threads=num_prefetch_threads,
+ drop_remainder=drop_remainder,
)
if num_epochs is None:
ds = lazy_dataset.RepeatLazyMapDataset(ds, num_epochs=None)
@@ -230,7 +233,7 @@ def get_dataset(
if runtime_preprocessors:
ops.extend(runtime_preprocessors)
if batch_size:
- ops.append(grain.Batch(batch_size=batch_size, drop_remainder=False))
+ ops.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))

# Add runtime args
runtime_args = core_preprocessors_lib.AirIOInjectedRuntimeArgs(
@@ -300,6 +303,7 @@ def get_dataset_by_step(
batch_size: int | None = None,
shuffle: bool = True,
seed: int | None = 0,
+ drop_remainder: bool = False,
) -> Iterable[Iterable[Mapping[str, Any]]]:
"""Returns a step-by-step transformation of a sample of records.

@@ -314,6 +318,7 @@
batch_size: the batch size.
shuffle: whether to shuffle or not.
seed: dataset seed.
+ drop_remainder: whether to drop the last batch if it's smaller than batch_size.

Returns: a list indexed by processing step. For example:
|-----------------------------|
@@ -345,7 +350,7 @@ def get_dataset_by_step(
if runtime_preprocessors:
all_ops.extend(runtime_preprocessors)
if batch_size:
- all_ops.append(grain.Batch(batch_size=batch_size, drop_remainder=False))
+ all_ops.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))

# Raw data
records_step0 = self._load_data(source=source, sampler=sampler, ops=[])
@@ -432,6 +437,7 @@ def get_lazy_dataset(
shard_info: core_dataset_providers.ShardInfo | None = None,
num_epochs: int | None = 1,
num_prefetch_threads: int | None = None,
+ drop_remainder: bool = False,
) -> lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset:
"""Returns a lazy dataset for the Mixture."""
if num_epochs is None and shuffle:
@@ -454,6 +460,7 @@
shard_info=shard_info,
num_epochs=num_epochs,
num_prefetch_threads=num_prefetch_threads,
+ drop_remainder=drop_remainder,
)
)
proportions.append(self.get_proportion(task))
@@ -493,7 +500,7 @@ def get_lazy_dataset(
post_mix_preps.extend(runtime_preprocessors)
if batch_size:
post_mix_preps.append(
- grain.Batch(batch_size=batch_size, drop_remainder=False)
+ grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)
)
# Note: Use updated runtime args from the first Task. All updated runtime
# args must match, or mixing won't work (compute all updated runtime args
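
For reference, a minimal usage sketch of the new flag (not part of the diff). Assumptions are marked below: `my_task` is a hypothetical airio pygrain Task instance, and the argument values are illustrative; only the `drop_remainder` keyword itself comes from this change.

# Hedged sketch, assuming `my_task` is an existing airio pygrain Task
# (hypothetical name); only `drop_remainder` is introduced by this PR.
iterator = my_task.get_dataset(
    batch_size=32,
    num_epochs=1,
    drop_remainder=True,  # was previously hard-coded to False
)
# Effect (grain.Batch semantics): with 100 examples and batch_size=32,
# drop_remainder=False yields batches of sizes [32, 32, 32, 4];
# drop_remainder=True yields [32, 32, 32], discarding the short batch.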