diff --git a/whisper/audio.py b/whisper/audio.py index cf6c66ad9..b895acd3e 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -20,39 +20,48 @@ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token - - def load_audio(file: str, sr: int = SAMPLE_RATE): """ - Open an audio file and read as mono waveform, resampling as necessary + Loads an audio file as a mono waveform, resampling to the specified sample rate. Parameters ---------- - file: str - The audio file to open - - sr: int - The sample rate to resample the audio if necessary + file : str + Path to the audio file. + sr : int, optional + Target sample rate for resampling, defaults to SAMPLE_RATE. Returns ------- - A NumPy array containing the audio waveform, in float32 dtype. + np.ndarray + 1D NumPy array of the audio waveform, normalized between -1 and 1. + + Raises + ------ + RuntimeError + If the audio cannot be loaded. + + Notes + ----- + Requires ffmpeg installed and accessible in the system's PATH. """ # This launches a subprocess to decode audio while down-mixing # and resampling as necessary. Requires the ffmpeg CLI in PATH. # fmt: off + cmd = [ - "ffmpeg", - "-nostdin", - "-threads", "0", - "-i", file, - "-f", "s16le", - "-ac", "1", - "-acodec", "pcm_s16le", - "-ar", str(sr), - "-" - ] + "ffmpeg", # Command to run the ffmpeg tool. + "-nostdin", # Prevents ffmpeg from reading from stdin. + "-threads", "0", # Uses all available CPU cores for processing. + "-i", file, # Specifies the input file path. + "-f", "s16le", # Sets the output format to 16-bit PCM. + "-ac", "1", # Converts audio to mono (1 channel). + "-acodec", "pcm_s16le", # Specifies the audio codec as PCM signed 16-bit little-endian. + "-ar", str(sr), # Resamples the audio to the specified sample rate. + "-" # Outputs the processed audio to stdout. +] + # fmt: on try: out = run(cmd, capture_output=True, check=True).stdout @@ -63,8 +72,22 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ - Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + Pads or trims the input array to a specified length along the given axis. + + Parameters: + - array: Input array (torch.Tensor or np.ndarray). + - length: Target length along the specified axis (default is N_SAMPLES). + - axis: Axis to pad or trim (default is -1 for the last axis). + + Returns: + - The modified array, either padded with zeros or trimmed to the target length. + + Note: + - The function handles both PyTorch tensors and NumPy arrays, applying appropriate methods + for padding and trimming depending on the array type. + """ if torch.is_tensor(array): if array.shape[axis] > length: @@ -91,14 +114,30 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): @lru_cache(maxsize=None) def mel_filters(device, n_mels: int) -> torch.Tensor: """ - load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Allows decoupling librosa dependency; saved using: - - np.savez_compressed( - "mel_filters.npz", - mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), - mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), - ) + Loads a precomputed Mel filterbank matrix for converting STFT to a Mel spectrogram. + + Parameters + ---------- + device : torch.device + The device (CPU or GPU) to load the tensor onto. + n_mels : int + The number of Mel bands, must be either 80 or 128. + + Returns + ------- + torch.Tensor + A tensor containing the Mel filterbank matrix. + + Raises + ------ + AssertionError + If `n_mels` is not supported. + + Notes + ----- + The Mel filterbank matrices are saved in a compressed npz file, which decouples + the dependency on librosa for generating these filters. + """ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" @@ -114,27 +153,32 @@ def log_mel_spectrogram( device: Optional[Union[str, torch.device]] = None, ): """ - Compute the log-Mel spectrogram of + Computes the log-Mel spectrogram of an audio waveform. Parameters ---------- - audio: Union[str, np.ndarray, torch.Tensor], shape = (*) - The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz - - n_mels: int - The number of Mel-frequency filters, only 80 is supported - - padding: int - Number of zero samples to pad to the right - - device: Optional[Union[str, torch.device]] - If given, the audio tensor is moved to this device before STFT + audio : Union[str, np.ndarray, torch.Tensor] + The audio input, either as a file path, NumPy array, or Torch tensor. + The waveform should be in 16 kHz. + n_mels : int, optional + The number of Mel-frequency filters, only 80 is supported. Defaults to 80. + padding : int, optional + Number of zero samples to pad at the end of the audio. Defaults to 0. + device : Optional[Union[str, torch.device]], optional + The device to perform computations on. If provided, the audio tensor is moved + to this device. Defaults to None. Returns ------- - torch.Tensor, shape = (80, n_frames) - A Tensor that contains the Mel spectrogram + torch.Tensor + A tensor containing the Mel spectrogram with shape (80, n_frames). + + Notes + ----- + The function expects a 16 kHz sampling rate for the input audio and uses a Hann + window for the Short-Time Fourier Transform (STFT). """ + if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio)