Whisper improvements (#1080)

* use safetensors in whisper

* speed up decoder

* bump version
Awni Hannun 2024-11-01 10:52:28 -07:00, committed by GitHub
parent 85ffd2c96a
commit 8160e0c4e5
6 changed files with 85 additions and 64 deletions

========================================

@@ -181,7 +181,7 @@ def load_torch_weights_and_config(
     )
     if name_or_path.endswith(".pt"):
-        checkpoint = torch.load(name_or_path, map_location="cpu")
+        checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
         weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
     else:
         name_or_path = Path(name_or_path)

@@ -387,7 +387,7 @@ if __name__ == "__main__":
     # Save weights
     print("[INFO] Saving")
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
     # Save config.json with model_type
     with open(str(mlx_path / "config.json"), "w") as f:
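A note on the format change: safetensors stores a flat mapping of names to raw tensor bytes behind a small JSON header, so it loads without unpickling and can be memory-mapped. A minimal sketch of the round trip this enables (the parameter name and shape below are made up for illustration):

    import mlx.core as mx

    # Hypothetical parameter; real Whisper checkpoints hold many entries,
    # e.g. "encoder.blocks.0.attn.query.weight".
    weights = {"proj.weight": mx.random.normal((4, 4))}

    mx.save_safetensors("weights.safetensors", weights)

    # mx.load dispatches on the file extension, which is what lets the
    # load_models.py change further down fall back to legacy .npz files.
    loaded = mx.load("weights.safetensors")
    assert loaded["proj.weight"].shape == (4, 4)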

========================================

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.3.0"
+__version__ = "0.4.0"

========================================

@@ -58,11 +58,12 @@ def detect_language(
     logits = model.logits(x, mel)[:, 0]

     # collect detected languages; suppress all non-language tokens
-    mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
+    mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
     mask[list(tokenizer.all_language_tokens)] = 0.0
-    logits += mx.array(mask)
+    logits += mask
     language_tokens = mx.argmax(logits, axis=-1)
     language_token_probs = mx.softmax(logits, axis=-1)
+    language_token_probs = np.array(language_token_probs)
     language_probs = [
         {
             c: language_token_probs[i, j].item()
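Why the reshuffle: the mask now lives on-device from the start (no NumPy-to-MLX copy), and the probabilities are pulled to NumPy once before the per-language `.item()` lookups rather than forcing a small evaluation per lookup. A sketch of the difference, with toy shapes:

    import mlx.core as mx
    import numpy as np

    probs = mx.softmax(mx.random.normal((2, 8)), axis=-1)

    # Indexing a lazy mx.array and calling .item() in a loop pays the
    # evaluate-and-sync overhead on every element...
    slow = [probs[i, j].item() for i in range(2) for j in range(8)]

    # ...converting once materializes everything in a single pass, and the
    # loop then reads plain host memory.
    host = np.array(probs)
    fast = [host[i, j].item() for i in range(2) for j in range(8)]
    assert slow == fast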
@@ -129,17 +130,12 @@ class DecodingResult:
 class Inference:
-    def __init__(self, model: "Whisper", initial_token_length: int):
+    def __init__(self, model: "Whisper"):
         self.model: "Whisper" = model
-        self.initial_token_length = initial_token_length
         self.kv_cache = None

     def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
         """Perform a forward pass on the decoder and return per-token logits"""
-        if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
-            tokens = tokens[:, -1:]
-
         logits, self.kv_cache, _ = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
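Dropping `initial_token_length` works because the caller now controls what is fed to the decoder: the first `_step` in `_main_loop` (further down this diff) passes the full prompt to populate the KV cache, and every later step passes only `tokens[:, -1:]`. A toy version of that contract, with a stand-in decoder that just counts cache positions (everything here is hypothetical):

    import mlx.core as mx

    class ToyInference:
        """Mirrors the slimmed-down Inference: the cache lives here,
        but the caller decides how many tokens to send."""

        def __init__(self):
            self.kv_cache = None  # grows by one position per decoded token

        def logits(self, tokens: mx.array) -> mx.array:
            # Stand-in for model.decoder(tokens, audio_features, kv_cache=...)
            n_cached = 0 if self.kv_cache is None else self.kv_cache
            self.kv_cache = n_cached + tokens.shape[-1]
            return mx.zeros((tokens.shape[0], tokens.shape[-1], 8))

    inf = ToyInference()
    prompt = mx.zeros((1, 4), dtype=mx.int32)
    inf.logits(prompt)          # first pass: full prompt fills the cache
    inf.logits(prompt[:, -1:])  # later passes: only the newest token
    assert inf.kv_cache == 5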
@@ -251,6 +247,11 @@ class TokenDecoder:
         raise NotImplementedError


+@mx.compile
+def categorical(logits, temp):
+    return mx.random.categorical(logits / temp)
+
+
 class GreedyDecoder(TokenDecoder):
     def __init__(self, temperature: float, eot: int):
         self.temperature = temperature
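`mx.compile` traces a function on first call and replays the fused graph afterwards; per MLX's compile semantics, changing the value of a non-array argument such as `temp` triggers a retrace, which is harmless here since the temperature is fixed for a decoding run. Standalone:

    import mlx.core as mx

    @mx.compile
    def categorical(logits, temp):
        # temperature-scaled sampling from unnormalized log-probabilities
        return mx.random.categorical(logits / temp)

    logits = mx.random.normal((2, 16))
    tok = categorical(logits, 0.7)  # first call traces and compiles
    tok = categorical(logits, 0.7)  # later calls reuse the compiled graph
    print(tok.shape)  # (2,)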
@@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder):
         if self.temperature == 0:
             next_tokens = logits.argmax(axis=-1)
         else:
-            next_tokens = mx.random.categorical(logits=logits / self.temperature)
-            next_tokens = mx.argmax(logits, axis=-1)
+            next_tokens = categorical(logits, self.temperature)

-        logits = logits.astype(mx.float32)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
         current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
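Aside from routing through the compiled helper, this hunk also deletes a stray `next_tokens = mx.argmax(logits, axis=-1)` that immediately overwrote the sampled tokens, so temperatures above zero previously behaved like greedy decoding. A toy check of the two paths (seed and logits made up):

    import mlx.core as mx

    mx.random.seed(0)
    logits = mx.array([[2.0, 1.9, -5.0, -5.0]])

    greedy = logits.argmax(axis=-1)                # always token 0
    sampled = mx.random.categorical(logits / 0.9)  # token 0 or 1, randomly

    # The old code computed `sampled` and then clobbered it with argmax;
    # the fixed code returns `sampled` whenever temperature > 0.
    print(greedy.item(), sampled.item())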
@@ -281,7 +280,7 @@
     def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
         # make sure each sequence has at least one EOT token at the end
         tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
-        return tokens, sum_logprobs.tolist()
+        return tokens, sum_logprobs


 class LogitFilter:
@@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter):
         if self.tokenizer.no_timestamps is not None:
             mask[:, self.tokenizer.no_timestamps] = -np.inf

         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
-        for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
-            seq = sampled_tokens.tolist()
+        tokens = tokens.tolist()
+        for k in range(len(tokens)):
+            seq = tokens[k][self.sample_begin :]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             )
@@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter):
                     last_timestamp += 1
                 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf

-        if tokens.shape[1] == self.sample_begin:
+        if len(tokens[0]) == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             mask[:, : self.tokenizer.timestamp_begin] = -np.inf
@@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter):
                 mask[:, last_allowed + 1 :] = -np.inf

         # if sum of probability over timestamps is above any other token, sample timestamp
+        mask = mx.array(mask)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
-        for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
-                axis=-1
-            )
-            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
-                mask[k, : self.tokenizer.timestamp_begin] = -np.inf
-
-        return logits + mx.array(mask, logits.dtype)
+        timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
+            axis=-1, keepdims=True
+        )
+        max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
+            axis=-1, keepdims=True
+        )
+        mask[:, : self.tokenizer.timestamp_begin] = mx.where(
+            timestamp_logprob > max_text_token_logprob,
+            -mx.inf,
+            mask[:, : self.tokenizer.timestamp_begin],
+        )
+        return logits + mask


 class DecodingTask:
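The per-row Python loop becomes one batched comparison: with `keepdims=True` both reductions stay shaped `(n_batch, 1)`, so the `mx.where` condition broadcasts across the text-token columns. The same pattern at toy sizes (the vocabulary split point below is made up):

    import mlx.core as mx

    n_batch, n_vocab, ts_begin = 3, 10, 6  # hypothetical vocab split
    logits = mx.random.normal((n_batch, n_vocab))
    mask = mx.zeros((n_batch, n_vocab))

    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    # (n_batch, 1): total timestamp mass and best text token, per row
    ts_logprob = logprobs[:, ts_begin:].logsumexp(axis=-1, keepdims=True)
    max_text = logprobs[:, :ts_begin].max(axis=-1, keepdims=True)

    # broadcast the (n_batch, 1) condition over the text-token slice
    mask[:, :ts_begin] = mx.where(ts_logprob > max_text, -mx.inf, mask[:, :ts_begin])
    print(mask.shape)  # (3, 10)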
@@ -424,7 +427,7 @@ class DecodingTask:
         self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

         # inference: implements the forward pass through the decoder, including kv caching
-        self.inference = Inference(model, len(self.initial_tokens))
+        self.inference = Inference(model)

         # sequence ranker: implements how to rank a group of sampled sequences
         self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
@@ -432,9 +435,6 @@ class DecodingTask:
         # decoder: implements how to select the next tokens, given the autoregressive distribution
         if options.beam_size is not None:
             raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
         else:
             self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
@@ -448,6 +448,7 @@ class DecodingTask:
             self.logit_filters.append(
                 SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
             )
+
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
@@ -570,23 +571,13 @@ class DecodingTask:
     def _main_loop(self, audio_features: mx.array, tokens: mx.array):
         n_batch = tokens.shape[0]
-        sum_logprobs: mx.array = mx.zeros(n_batch)
-        no_speech_probs = [np.nan] * n_batch
+        sum_logprobs = mx.zeros(n_batch)

-        try:
-            for i in range(self.sample_len):
-                logits = self.inference.logits(tokens, audio_features)
-
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
-                    probs_at_sot = mx.softmax(
-                        logits[:, self.sot_index].astype(mx.float32), axis=-1
-                    )
-                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-
-                # now we need to consider the logits at the last token only
-                logits = logits[:, -1]
+        def _step(inputs, audio_features, tokens, sum_logprobs):
+            pre_logits = self.inference.logits(inputs, audio_features)
+
+            # consider the logits at the last token only
+            logits = pre_logits[:, -1]

-                # apply the logit filters, e.g. for suppressing or applying penalty to
-                for logit_filter in self.logit_filters:
+            # apply the logit filters, e.g. for suppressing or applying penalty to
+            for logit_filter in self.logit_filters:
@@ -596,9 +587,31 @@ class DecodingTask:
-                tokens, completed, sum_logprobs = self.decoder.update(
-                    tokens, logits, sum_logprobs
-                )
-
-                if completed or tokens.shape[-1] > self.n_ctx:
-                    break
+            tokens, completed, sum_logprobs = self.decoder.update(
+                tokens, logits, sum_logprobs
+            )
+            return tokens, completed, sum_logprobs, pre_logits
+
+        try:
+            tokens, completed, sum_logprobs, pre_logits = _step(
+                tokens, audio_features, tokens, sum_logprobs
+            )
+            if self.tokenizer.no_speech is not None:  # compute no_speech_probs
+                probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
+                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
+            else:
+                no_speech_probs = mx.full(n_batch, mx.nan)
+            mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
+
+            for i in range(1, self.sample_len):
+                inputs = tokens[:, -1:]
+                next_tokens, next_completed, next_sum_logprobs, _ = _step(
+                    inputs, audio_features, tokens, sum_logprobs
+                )
+                mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
+                if completed or tokens.shape[-1] > self.n_ctx:
+                    break
+                tokens = next_tokens
+                completed = next_completed
+                sum_logprobs = next_sum_logprobs
         finally:
             self.inference.reset()
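This loop is the substance of the "speed up decoder" bullet. MLX is lazy, and `mx.async_eval` dispatches evaluation without blocking, so step i+1 can be built and queued while step i is still executing; the Python truth test on `completed` then only waits for a scalar that is already in flight. Stripped to the pattern (toy step function, arbitrary workload):

    import mlx.core as mx

    def step(x):
        # stand-in for one decode step: some lazy compute plus a stop flag
        y = mx.tanh(x @ x.T).sum(axis=-1)
        return y, (y.max() > 1e9)

    x = mx.random.normal((256, 256))
    out, done = step(x)
    mx.async_eval(done, out)  # queue step 0 without blocking

    for i in range(1, 10):
        next_out, next_done = step(x)  # build step i while step i-1 runs
        mx.async_eval(next_done, next_out)
        if done:                       # forces only the in-flight scalar
            break
        out, done = next_out, next_done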
@@ -610,8 +623,8 @@ class DecodingTask:
         n_audio: int = mel.shape[0]

         audio_features: mx.array = self._get_audio_features(mel)  # encoder forward pass
-        tokens: np.array = np.array(self.initial_tokens)
-        tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
+        tokens: mx.array = mx.array(self.initial_tokens)
+        tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))

         # detect language if requested, overwriting the language token
         languages, language_probs = self._detect_language(audio_features, tokens)
@@ -626,7 +639,6 @@ class DecodingTask:
         ]

         # repeat tokens by the group size, for beam search or best-of-n sampling
-        tokens = mx.array(tokens)
         if self.n_group > 1:
             tokens = tokens[:, None, :]
             tokens = mx.broadcast_to(
@@ -649,7 +661,13 @@ class DecodingTask:
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens = tokens[..., self.sample_begin :].tolist()
+        tokens = tokens[..., self.sample_begin :]
+
+        # eval and convert to list
+        mx.eval(tokens, sum_logprobs, no_speech_probs)
+        tokens = tokens.tolist()
+        sum_logprobs = sum_logprobs.tolist()
+        no_speech_probs = no_speech_probs.tolist()
         tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]

         # select the top-ranked sample in each group
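Same idea at the end of a decode: one `mx.eval` over all three outputs evaluates their shared graph once, where three bare `.tolist()` calls would each synchronize separately. Sketch:

    import mlx.core as mx

    a = mx.random.normal((4,))
    b = a * 2
    c = mx.cumsum(a)

    mx.eval(a, b, c)  # one evaluation of the shared graph...
    a, b, c = a.tolist(), b.tolist(), c.tolist()  # ...then cheap host copies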

========================================

@@ -26,7 +26,10 @@ def load_model(
     model_args = whisper.ModelDimensions(**config)

-    weights = mx.load(str(model_path / "weights.npz"))
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))

     model = whisper.Whisper(model_args, dtype)

========================================

@@ -293,6 +293,7 @@ def transcribe(
             decode_options["prompt"] = all_tokens[prompt_reset_since:]

             result: DecodingResult = decode_with_fallback(mel_segment)
+
             tokens = np.array(result.tokens)

             if no_speech_threshold is not None:

========================================

@@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.astype(mx.float32)

-        w = mx.softmax(qk, axis=-1).astype(q.dtype)
+        w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk
+        return out, qk.astype(mx.float32)


 class ResidualAttentionBlock(nn.Module):
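`mx.softmax(..., precise=True)` accumulates in higher precision inside the kernel while returning the input dtype, replacing the explicit upcast/downcast pair; the `float32` cast survives only on the returned `qk`, which callers consume for alignment. A quick fp16 equivalence check (tolerance picked arbitrarily):

    import mlx.core as mx

    qk = mx.random.normal((2, 4, 8)).astype(mx.float16)

    # old: upcast, softmax, downcast
    w_old = mx.softmax(qk.astype(mx.float32), axis=-1).astype(mx.float16)
    # new: accumulate in higher precision inside the kernel
    w_new = mx.softmax(qk, axis=-1, precise=True)

    assert mx.allclose(w_old, w_new, atol=1e-3)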