def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
Hi, thanks for the excellent work. I have a question about get_batch: this code does not make the data unique across the distributed processes, the way PyTorch's DistributedSampler does (indices = indices[self.rank:self.total_size:self.num_replicas]). The data may overlap, because every process draws its own random indices with "ix = torch.randint(len(data) - block_size, (batch_size,))". Is this an efficiency problem?
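For comparison, here is a minimal sketch contrasting the two sampling schemes the question refers to: a DistributedSampler-style rank-strided partition of a shuffled index list versus the independent per-rank random sampling that get_batch uses. This is illustrative only; the function names and the seed handling are assumptions, not code from the repository.

import torch

def strided_indices(dataset_len, rank, world_size, seed=0):
    # DistributedSampler-style partition: shuffle once with a seed shared by
    # all ranks, then give each rank a disjoint stride of the shuffled
    # indices, analogous to indices[self.rank:self.total_size:self.num_replicas].
    g = torch.Generator().manual_seed(seed)
    indices = torch.randperm(dataset_len, generator=g)
    return indices[rank::world_size]

def random_offsets(dataset_len, batch_size, block_size):
    # get_batch-style sampling: each rank independently draws random window
    # offsets, so the same window can in principle be drawn by several ranks.
    return torch.randint(dataset_len - block_size, (batch_size,))

With the strided scheme, no two ranks ever see the same index in an epoch; with independent random offsets, occasional overlap between ranks is possible, though each rank still sees different batches as long as the ranks use different RNG seeds.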