mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Whisper improvements (#1080)
* use safetensors in whisper * speed up decoder * version
This commit is contained in:
parent
85ffd2c96a
commit
8160e0c4e5
@ -181,7 +181,7 @@ def load_torch_weights_and_config(
|
||||
)
|
||||
|
||||
if name_or_path.endswith(".pt"):
|
||||
checkpoint = torch.load(name_or_path, map_location="cpu")
|
||||
checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
|
||||
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
||||
else:
|
||||
name_or_path = Path(name_or_path)
|
||||
@ -387,7 +387,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Save weights
|
||||
print("[INFO] Saving")
|
||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
||||
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(str(mlx_path / "config.json"), "w") as f:
|
||||
|
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.3.0"
|
||||
__version__ = "0.4.0"
|
||||
|
@ -58,11 +58,12 @@ def detect_language(
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
|
||||
mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
|
||||
mask[list(tokenizer.all_language_tokens)] = 0.0
|
||||
logits += mx.array(mask)
|
||||
logits += mask
|
||||
language_tokens = mx.argmax(logits, axis=-1)
|
||||
language_token_probs = mx.softmax(logits, axis=-1)
|
||||
language_token_probs = np.array(language_token_probs)
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
@ -129,17 +130,12 @@ class DecodingResult:
|
||||
|
||||
|
||||
class Inference:
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
def __init__(self, model: "Whisper"):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = None
|
||||
|
||||
def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
logits, self.kv_cache, _ = self.model.decoder(
|
||||
tokens, audio_features, kv_cache=self.kv_cache
|
||||
)
|
||||
@ -251,6 +247,11 @@ class TokenDecoder:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@mx.compile
|
||||
def categorical(logits, temp):
|
||||
return mx.random.categorical(logits / temp)
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder):
|
||||
if self.temperature == 0:
|
||||
next_tokens = logits.argmax(axis=-1)
|
||||
else:
|
||||
next_tokens = mx.random.categorical(logits=logits / self.temperature)
|
||||
next_tokens = categorical(logits, self.temperature)
|
||||
|
||||
next_tokens = mx.argmax(logits, axis=-1)
|
||||
logits = logits.astype(mx.float32)
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
||||
|
||||
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
|
||||
@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder):
|
||||
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
return tokens, sum_logprobs
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter):
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
mask[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
sampled_tokens = tokens[k, self.sample_begin :]
|
||||
seq = sampled_tokens.tolist()
|
||||
## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
tokens = tokens.tolist()
|
||||
for k in range(len(tokens)):
|
||||
seq = tokens[k][self.sample_begin :]
|
||||
last_was_timestamp = (
|
||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter):
|
||||
last_timestamp += 1
|
||||
mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
if len(tokens[0]) == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
mask[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter):
|
||||
mask[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
mask = mx.array(mask)
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
axis=-1
|
||||
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
axis=-1, keepdims=True
|
||||
)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
mask[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
return logits + mx.array(mask, logits.dtype)
|
||||
max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
|
||||
axis=-1, keepdims=True
|
||||
)
|
||||
mask[:, : self.tokenizer.timestamp_begin] = mx.where(
|
||||
timestamp_logprob > max_text_token_logprob,
|
||||
-mx.inf,
|
||||
mask[:, : self.tokenizer.timestamp_begin],
|
||||
)
|
||||
return logits + mask
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
@ -424,7 +427,7 @@ class DecodingTask:
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = Inference(model, len(self.initial_tokens))
|
||||
self.inference = Inference(model)
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
@ -432,9 +435,6 @@ class DecodingTask:
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
raise NotImplementedError("Beam search decoder is not yet implemented")
|
||||
# self.decoder = BeamSearchDecoder(
|
||||
# options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
# )
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
@ -448,6 +448,7 @@ class DecodingTask:
|
||||
self.logit_filters.append(
|
||||
SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
|
||||
)
|
||||
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
@ -570,23 +571,13 @@ class DecodingTask:
|
||||
|
||||
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: mx.array = mx.zeros(n_batch)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
sum_logprobs = mx.zeros(n_batch)
|
||||
|
||||
try:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
def _step(inputs, audio_features, tokens, sum_logprobs):
|
||||
pre_logits = self.inference.logits(inputs, audio_features)
|
||||
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
probs_at_sot = mx.softmax(
|
||||
logits[:, self.sot_index].astype(mx.float32), axis=-1
|
||||
)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
# consider the logits at the last token only
|
||||
logits = pre_logits[:, -1]
|
||||
|
||||
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||
for logit_filter in self.logit_filters:
|
||||
@ -596,9 +587,31 @@ class DecodingTask:
|
||||
tokens, completed, sum_logprobs = self.decoder.update(
|
||||
tokens, logits, sum_logprobs
|
||||
)
|
||||
return tokens, completed, sum_logprobs, pre_logits
|
||||
|
||||
try:
|
||||
tokens, completed, sum_logprobs, pre_logits = _step(
|
||||
tokens, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
if self.tokenizer.no_speech is not None: # compute no_speech_probs
|
||||
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
|
||||
else:
|
||||
no_speech_probs = mx.full(n_batch, mx.nan)
|
||||
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
|
||||
|
||||
for i in range(1, self.sample_len):
|
||||
inputs = tokens[:, -1:]
|
||||
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||
inputs, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
tokens = next_tokens
|
||||
completed = next_completed
|
||||
sum_logprobs = next_sum_logprobs
|
||||
|
||||
finally:
|
||||
self.inference.reset()
|
||||
|
||||
@ -610,8 +623,8 @@ class DecodingTask:
|
||||
n_audio: int = mel.shape[0]
|
||||
|
||||
audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
|
||||
tokens: np.array = np.array(self.initial_tokens)
|
||||
tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
|
||||
tokens: mx.array = mx.array(self.initial_tokens)
|
||||
tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
@ -626,7 +639,6 @@ class DecodingTask:
|
||||
]
|
||||
|
||||
# repeat tokens by the group size, for beam search or best-of-n sampling
|
||||
tokens = mx.array(tokens)
|
||||
if self.n_group > 1:
|
||||
tokens = tokens[:, None, :]
|
||||
tokens = mx.broadcast_to(
|
||||
@ -649,7 +661,13 @@ class DecodingTask:
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens = tokens[..., self.sample_begin :].tolist()
|
||||
tokens = tokens[..., self.sample_begin :]
|
||||
|
||||
# eval and convert to list
|
||||
mx.eval(tokens, sum_logprobs, no_speech_probs)
|
||||
tokens = tokens.tolist()
|
||||
sum_logprobs = sum_logprobs.tolist()
|
||||
no_speech_probs = no_speech_probs.tolist()
|
||||
tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
|
@ -26,7 +26,10 @@ def load_model(
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
weights = mx.load(str(model_path / "weights.npz"))
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
|
@ -293,6 +293,7 @@ def transcribe(
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
|
||||
tokens = np.array(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
|
@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.astype(mx.float32)
|
||||
|
||||
w = mx.softmax(qk, axis=-1).astype(q.dtype)
|
||||
w = mx.softmax(qk, axis=-1, precise=True)
|
||||
out = (w @ v).transpose(0, 2, 1, 3)
|
||||
out = out.reshape(n_batch, n_ctx, n_state)
|
||||
return out, qk
|
||||
return out, qk.astype(mx.float32)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user