# dataset.py
## Modified based on https://github.com/pytorch/audio/blob/master/torchaudio/datasets/speechcommands.py
import os
from pathlib import Path
from typing import Tuple

import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"


def fix_length(tensor, length):
    # Pad with trailing zeros or truncate so the (1, T) waveform has exactly `length` samples.
    assert len(tensor.shape) == 2 and tensor.shape[0] == 1
    if tensor.shape[1] > length:
        return tensor[:, :length]
    elif tensor.shape[1] < length:
        return torch.cat([tensor, torch.zeros(1, length - tensor.shape[1])], dim=1)
    else:
        return tensor

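# Example (illustrative sketch, not a definitive contract): fix_length pads short
# clips with zeros and truncates long ones so every waveform has exactly `length`
# samples, e.g.
#   fix_length(torch.zeros(1, 12000), 16000).shape  # torch.Size([1, 16000])
#   fix_length(torch.zeros(1, 20000), 16000).shape  # torch.Size([1, 16000])
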
def load_speechcommands_item(filepath: str, path: str):
    # Recover the label (the parent folder name) and speaker/utterance ids from the file path.
    relpath = os.path.relpath(filepath, path)
    label, filename = os.path.split(relpath)
    speaker, _ = os.path.splitext(filename)
    speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
    utterance_number = int(utterance_number)

    # Load audio and pad/truncate to one second at 16 kHz.
    waveform, sample_rate = torchaudio.load(filepath)
    return (fix_length(waveform, length=16000), sample_rate, label)

class SPEECHCOMMANDS(Dataset):
    """
    Create a Dataset for Speech Commands. Each item is a tuple of the form:
    (waveform, sample_rate, label)
    """

    def __init__(self, root: str, folder_in_archive: str):
        self._path = os.path.join(root, folder_in_archive)
        # walker = walk_files(self._path, suffix=".wav", prefix=True)
        # Collect all .wav files one level below the label folders, keeping only
        # files with the "_nohash_" marker and skipping the background-noise folder.
        walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
        walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
        self._walker = list(walker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
        fileid = self._walker[n]
        return load_speechcommands_item(fileid, self._path)

    def __len__(self) -> int:
        return len(self._walker)

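# Minimal usage sketch (the root path and version folder name below are assumptions;
# the class only expects <root>/<folder_in_archive>/<label>/*.wav):
#   dataset = SPEECHCOMMANDS(root="./data", folder_in_archive="speech_commands_v0.02")
#   waveform, sample_rate, label = dataset[0]
#   # waveform: Tensor of shape (1, 16000), sample_rate: 16000, label: e.g. "yes"
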
def load_Speech_commands(path, batch_size=4, num_gpus=1):
    """
    Load the Speech Commands dataset and wrap it in a DataLoader.
    """
    Speech_commands_dataset = SPEECHCOMMANDS(root=path, folder_in_archive='')

    # distributed sampler
    train_sampler = DistributedSampler(Speech_commands_dataset) if num_gpus > 1 else None

    trainloader = torch.utils.data.DataLoader(Speech_commands_dataset,
                                              batch_size=batch_size,
                                              sampler=train_sampler,
                                              num_workers=4,
                                              pin_memory=False,
                                              drop_last=True)
    return trainloader
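

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: "./data/speech_commands_v0.02" is an
    # assumed location of the extracted Speech Commands archive (it is not part
    # of this script); adjust the path to your setup.
    loader = load_Speech_commands("./data/speech_commands_v0.02", batch_size=4, num_gpus=1)
    waveform, sample_rate, label = next(iter(loader))
    # Expect a (batch_size, 1, 16000) waveform batch, a batch of sample rates, and a tuple of labels.
    print(waveform.shape, sample_rate, label)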