Add large v3

parent 13f1142eaa
commit 94705ed38b
Binary file not shown. (The binary change is the mel_filters.npz asset, which gains the 128-bin filterbank referenced in the docstring change below.)
@@ -11,7 +11,6 @@ import numpy as np
 # hard-coded audio hyperparameters
 SAMPLE_RATE = 16000
 N_FFT = 400
-N_MELS = 80
 HOP_LENGTH = 160
 CHUNK_LENGTH = 30
 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
@@ -81,7 +80,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):


 @lru_cache(maxsize=None)
-def mel_filters(n_mels: int = N_MELS) -> mx.array:
+def mel_filters(n_mels: int) -> mx.array:
     """
     load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
     Allows decoupling librosa dependency; saved using:
@@ -89,9 +88,10 @@ def mel_filters(n_mels: int = N_MELS) -> mx.array:
         np.savez_compressed(
             "mel_filters.npz",
             mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
         )
     """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

     filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
     return mx.load(filename)[f"mel_{n_mels}"]
@@ -130,7 +130,7 @@ def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="re

 def log_mel_spectrogram(
     audio: Union[str, np.ndarray],
-    n_mels: int = N_MELS,
+    n_mels: int = 80,
     padding: int = 0,
 ):
     """
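Taken together, these audio hunks replace the hard-coded N_MELS with a per-call n_mels, which is what lets large-v3 run on 128 mel bins while older checkpoints keep 80. A minimal usage sketch of the new surface (the import path and "speech.wav" are illustrative assumptions, not from this commit):

from whisper.audio import log_mel_spectrogram, mel_filters  # import path is an assumption

filters = mel_filters(n_mels=128)                    # 128-bin filterbank from mel_filters.npz
mel = log_mel_spectrogram("speech.wav", n_mels=128)  # hypothetical input file
print(mel.shape)  # (n_frames, 128); pass n_mels=80 for pre-v3 checkpoints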
@@ -33,7 +33,9 @@ def detect_language(
         list of dictionaries containing the probability distribution over all languages.
     """
     if tokenizer is None:
-        tokenizer = get_tokenizer(model.is_multilingual)
+        tokenizer = get_tokenizer(
+            model.is_multilingual, num_languages=model.num_languages
+        )
     if (
         tokenizer.language is None
         or tokenizer.language_token not in tokenizer.sot_sequence
@@ -401,7 +403,10 @@ class DecodingTask:

         language = options.language or "en"
         tokenizer = get_tokenizer(
-            model.is_multilingual, language=language, task=options.task
+            model.is_multilingual,
+            num_languages=model.num_languages,
+            language=language,
+            task=options.task,
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -25,7 +25,8 @@ _MODELS = {
     "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
     "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
     "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
+    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
 }

 # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@@ -41,7 +42,8 @@ _ALIGNMENT_HEADS = {
     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
-    "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
 }


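The new "large-v3" entry uses the same packing as its neighbors: a base85 string of a gzip-compressed boolean array with one flag per (layer, head) pair. A hedged sketch of how such a blob unpacks, mirroring the decoding that openai/whisper performs in set_alignment_heads (the 32x20 shape assumes large-v3 keeps large-v2's 32 decoder layers of 20 heads):

import base64
import gzip

import numpy as np

blob = b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
flags = np.frombuffer(gzip.decompress(base64.b85decode(blob)), dtype=bool)
heads = flags.copy().reshape(32, 20)  # True marks a cross-attention head used for alignment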
@@ -109,6 +109,7 @@ LANGUAGES = {
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }

 # language code lookup by name, with a few language aliases
@@ -125,6 +126,7 @@ TO_LANGUAGE_CODE = {
     "moldovan": "ro",
     "sinhalese": "si",
     "castilian": "es",
+    "mandarin": "zh",
 }


@@ -133,6 +135,7 @@ class Tokenizer:
     """A thin wrapper around `tiktoken` providing quick access to special tokens"""

     encoding: tiktoken.Encoding
+    num_languages: int
     language: Optional[str] = None
     task: Optional[str] = None
     sot_sequence: Tuple[int] = ()
@@ -147,7 +150,7 @@ class Tokenizer:
         translate: int = self.special_tokens["<|translate|>"]
         transcribe: int = self.special_tokens["<|transcribe|>"]

-        langs = tuple(LANGUAGES.keys())
+        langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
             sot_sequence.append(sot + 1 + langs.index(self.language))
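The slice matters because sot_sequence derives language token ids positionally: the language specials sit in vocabulary order immediately after <|startoftranscript|>, so sot + 1 + langs.index(...) must line up with exactly num_languages entries. A worked sketch (the concrete ids are assumptions matching the standard multilingual layout, where sot is 50258):

sot = 50258                 # <|startoftranscript|> in the multilingual vocab
langs = ("en", "zh", "de")  # illustrative prefix; the real tuple has num_languages entries
assert sot + 1 + langs.index("en") == 50259  # <|en|> directly follows sot
# "yue", appended at index 99 in large-v3's 100-language list, would land on 50358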
@@ -213,10 +216,13 @@ class Tokenizer:
         if self.language is None:
             raise ValueError("This tokenizer does not have language token configured")

-        if token := self.special_tokens.get(f"<|{self.language}|>", None):
+        return self.to_language_token(self.language)
+
+    def to_language_token(self, language):
+        if token := self.special_tokens.get(f"<|{language}|>", None):
             return token

-        raise KeyError(f"Language {self.language} not found in tokenizer.")
+        raise KeyError(f"Language {language} not found in tokenizer.")

     @cached_property
     def all_language_tokens(self) -> Tuple[int]:
@@ -224,7 +230,7 @@ class Tokenizer:
         for token, token_id in self.special_tokens.items():
             if token.strip("<|>") in LANGUAGES:
                 result.append(token_id)
-        return tuple(result)
+        return tuple(result)[: self.num_languages]

     @cached_property
     def all_language_codes(self) -> Tuple[str]:
@@ -271,7 +277,7 @@ class Tokenizer:
         return tuple(sorted(result))

     def split_to_word_tokens(self, tokens: List[int]):
-        if self.language in {"zh", "ja", "th", "lo", "my"}:
+        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
             # These languages don't typically use spaces, so it is difficult to split words
             # without morpheme analysis. Here, we instead split words at any
             # position where the tokens are decoded as valid unicode points
@@ -324,7 +330,7 @@ class Tokenizer:


 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2"):
+def get_encoding(name: str = "gpt2", num_languages: int = 99):
     vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
     with open(vocab_path) as fid:
         ranks = {
@@ -337,7 +343,7 @@ def get_encoding(name: str = "gpt2"):
     specials = [
         "<|endoftext|>",
         "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
         "<|translate|>",
         "<|transcribe|>",
         "<|startoflm|>",
@@ -364,6 +370,7 @@ def get_encoding(name: str = "gpt2"):
 def get_tokenizer(
     multilingual: bool,
     *,
+    num_languages: int = 99,
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
 ) -> Tokenizer:
@@ -384,6 +391,8 @@ def get_tokenizer(
         language = None
         task = None

-    encoding = get_encoding(name=encoding_name)
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

-    return Tokenizer(encoding=encoding, language=language, task=task)
+    return Tokenizer(
+        encoding=encoding, num_languages=num_languages, language=language, task=task
+    )
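With num_languages threaded through get_encoding and into the Tokenizer, callers select the vocabulary variant explicitly. A hedged usage sketch (num_languages=100 matches large-v3 via the model property added further down; 99 is correct for earlier multilingual checkpoints):

tokenizer = get_tokenizer(
    multilingual=True,
    num_languages=100,  # use 99 for pre-v3 multilingual models
    language="yue",     # "yue" only resolves when the 100-language list is active
    task="transcribe",
)
print(tokenizer.sot_sequence)  # (sot, <|yue|>, <|transcribe|>) as token ids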
@@ -234,7 +234,8 @@ class Whisper(nn.Module):
             self.dims.n_text_head,
             self.dims.n_text_layer,
         )
-        # use the last half layers for alignment by default; see `set_alignment_heads()` below
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
         all_heads = torch.zeros(
             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
         )
@@ -267,7 +268,11 @@ class Whisper(nn.Module):

     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865

+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
         """
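The magic number 51765 makes the arithmetic come out right across the vocab layouts, so the new property can be sanity-checked against the known checkpoint vocab sizes:

# worked check of the new property (vocab sizes from the released checkpoints):
assert 51865 - 51765 - 1 == 99    # multilingual v1/v2: 99 languages
assert 51866 - 51765 - 1 == 100   # large-v3: 100 languages, adding "yue"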
@@ -119,7 +119,7 @@ def transcribe(
     dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32

     # Pad 30-seconds of silence to the input audio, for slicing
-    mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
+    mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
     content_frames = mel.shape[-2] - N_FRAMES

     if verbose:
@@ -150,7 +150,12 @@ def transcribe(

     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
-    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
+    tokenizer = get_tokenizer(
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language=language,
+        task=task,
+    )

     def decode_with_fallback(segment: mx.array) -> DecodingResult:
         temperatures = (
@@ -210,7 +210,11 @@ class Whisper(nn.Module):

     @property
     def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return self.dims.n_vocab >= 51865

+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
     detect_language = detect_language_function
     decode = decode_function