From 8160e0c4e56df261d0c8406f68d40b42ef0a188b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 1 Nov 2024 10:52:28 -0700
Subject: [PATCH] Whisper improvements (#1080)

* use safetensors in whisper

* speed up decoder

* version
---
 whisper/convert.py                 |   4 +-
 whisper/mlx_whisper/_version.py    |   2 +-
 whisper/mlx_whisper/decoding.py    | 132 ++++++++++++++++-------------
 whisper/mlx_whisper/load_models.py |   5 +-
 whisper/mlx_whisper/transcribe.py  |   1 +
 whisper/mlx_whisper/whisper.py     |   5 +-
 6 files changed, 85 insertions(+), 64 deletions(-)

diff --git a/whisper/convert.py b/whisper/convert.py
index cdd50bc5..301fd5b4 100644
--- a/whisper/convert.py
+++ b/whisper/convert.py
@@ -181,7 +181,7 @@ def load_torch_weights_and_config(
         )
 
     if name_or_path.endswith(".pt"):
-        checkpoint = torch.load(name_or_path, map_location="cpu")
+        checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
         weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
     else:
         name_or_path = Path(name_or_path)
@@ -387,7 +387,7 @@ if __name__ == "__main__":
 
     # Save weights
     print("[INFO] Saving")
-    np.savez(str(mlx_path / "weights.npz"), **weights)
+    mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
 
     # Save config.json with model_type
     with open(str(mlx_path / "config.json"), "w") as f:
diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py
index 67c7397c..45e522d1 100644
--- a/whisper/mlx_whisper/_version.py
+++ b/whisper/mlx_whisper/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.3.0"
+__version__ = "0.4.0"
diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py
index 41c2ec6d..6bf975d5 100644
--- a/whisper/mlx_whisper/decoding.py
+++ b/whisper/mlx_whisper/decoding.py
@@ -58,11 +58,12 @@ def detect_language(
     logits = model.logits(x, mel)[:, 0]
 
     # collect detected languages; suppress all non-language tokens
-    mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
+    mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
     mask[list(tokenizer.all_language_tokens)] = 0.0
-    logits += mx.array(mask)
+    logits += mask
     language_tokens = mx.argmax(logits, axis=-1)
     language_token_probs = mx.softmax(logits, axis=-1)
+    language_token_probs = np.array(language_token_probs)
     language_probs = [
         {
             c: language_token_probs[i, j].item()
@@ -129,17 +130,12 @@ class DecodingResult:
 
 
 class Inference:
-    def __init__(self, model: "Whisper", initial_token_length: int):
+    def __init__(self, model: "Whisper"):
         self.model: "Whisper" = model
-        self.initial_token_length = initial_token_length
         self.kv_cache = None
 
     def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
         """Perform a forward pass on the decoder and return per-token logits"""
-        if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
-            tokens = tokens[:, -1:]
-
         logits, self.kv_cache, _ = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
@@ -251,6 +247,11 @@ class TokenDecoder:
         raise NotImplementedError
 
 
+@mx.compile
+def categorical(logits, temp):
+    return mx.random.categorical(logits / temp)
+
+
 class GreedyDecoder(TokenDecoder):
     def __init__(self, temperature: float, eot: int):
         self.temperature = temperature
@@ -262,10 +263,8 @@
         if self.temperature == 0:
             next_tokens = logits.argmax(axis=-1)
         else:
-            next_tokens = mx.random.categorical(logits=logits / self.temperature)
+            next_tokens = categorical(logits, self.temperature)
 
-        next_tokens = mx.argmax(logits, axis=-1)
-        logits = logits.astype(mx.float32)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
         current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
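The new `categorical` helper wraps temperature scaling and sampling in `mx.compile`, so the division and the categorical draw are fused into one cached graph instead of being rebuilt on every decode step. A minimal standalone sketch of the same pattern (the vocabulary size and temperature are illustrative, not taken from this patch):

    import mlx.core as mx

    @mx.compile
    def sample(logits, temp):
        # scale-by-temperature and categorical draw run as one compiled graph
        return mx.random.categorical(logits / temp)

    logits = mx.random.normal((1, 51864))  # (batch, vocab)-shaped dummy logits
    print(sample(logits, 0.7).item())      # one sampled token id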
@@ -281,7 +280,7 @@
     def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
         # make sure each sequence has at least one EOT token at the end
         tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
-        return tokens, sum_logprobs.tolist()
+        return tokens, sum_logprobs
 
 
 class LogitFilter:
@@ -340,10 +339,10 @@
         if self.tokenizer.no_timestamps is not None:
             mask[:, self.tokenizer.no_timestamps] = -np.inf
 
-        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
-        for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
-            seq = sampled_tokens.tolist()
+        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+        tokens = tokens.tolist()
+        for k in range(len(tokens)):
+            seq = tokens[k][self.sample_begin :]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             )
@@ -368,7 +367,7 @@
                     last_timestamp += 1
                 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
 
-        if tokens.shape[1] == self.sample_begin:
+        if len(tokens[0]) == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             mask[:, : self.tokenizer.timestamp_begin] = -np.inf
 
@@ -380,16 +379,20 @@
             mask[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
+        mask = mx.array(mask)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
-        for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
-                axis=-1
-            )
-            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
-                mask[k, : self.tokenizer.timestamp_begin] = -np.inf
-
-        return logits + mx.array(mask, logits.dtype)
+        timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
+            axis=-1, keepdims=True
+        )
+        max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
+            axis=-1, keepdims=True
+        )
+        mask[:, : self.tokenizer.timestamp_begin] = mx.where(
+            timestamp_logprob > max_text_token_logprob,
+            -mx.inf,
+            mask[:, : self.tokenizer.timestamp_begin],
+        )
+        return logits + mask
 
 
 class DecodingTask:
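With the per-row Python loop replaced by array-wide ops, the timestamp-vs-text decision is now computed once for the whole batch: a keepdims `logsumexp` over timestamp logprobs, a keepdims `max` over text logprobs, and a broadcast `mx.where` that masks text tokens only in rows where timestamps dominate. A small sketch of that loop-to-broadcast rewrite (the shapes and the `boundary` stand-in for `tokenizer.timestamp_begin` are made up):

    import mlx.core as mx

    logprobs = mx.random.normal((4, 10))  # (batch, vocab) dummy logprobs
    boundary = 6                          # stand-in for tokenizer.timestamp_begin
    mask = mx.zeros((4, 10))

    # per-row statistics, computed for all rows at once
    ts_logprob = logprobs[:, boundary:].logsumexp(axis=-1, keepdims=True)
    max_text = logprobs[:, :boundary].max(axis=-1, keepdims=True)

    # rows where timestamps dominate get their text tokens masked out
    mask[:, :boundary] = mx.where(ts_logprob > max_text, -mx.inf, mask[:, :boundary])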
@@ -424,7 +427,7 @@
         self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
 
         # inference: implements the forward pass through the decoder, including kv caching
-        self.inference = Inference(model, len(self.initial_tokens))
+        self.inference = Inference(model)
 
         # sequence ranker: implements how to rank a group of sampled sequences
         self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
@@ -432,9 +435,6 @@
         # decoder: implements how to select the next tokens, given the autoregressive distribution
         if options.beam_size is not None:
             raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
         else:
             self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
 
@@ -448,6 +448,7 @@
         self.logit_filters.append(
             SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
         )
+
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
@@ -570,35 +571,47 @@
 
     def _main_loop(self, audio_features: mx.array, tokens: mx.array):
         n_batch = tokens.shape[0]
-        sum_logprobs: mx.array = mx.zeros(n_batch)
-        no_speech_probs = [np.nan] * n_batch
+        sum_logprobs = mx.zeros(n_batch)
+
+        def _step(inputs, audio_features, tokens, sum_logprobs):
+            pre_logits = self.inference.logits(inputs, audio_features)
+
+            # consider the logits at the last token only
+            logits = pre_logits[:, -1]
+
+            # apply the logit filters, e.g. for suppressing or applying penalty to
+            for logit_filter in self.logit_filters:
+                logits = logit_filter.apply(logits, tokens)
+
+            # expand the tokens tensor with the selected next tokens
+            tokens, completed, sum_logprobs = self.decoder.update(
+                tokens, logits, sum_logprobs
+            )
+            return tokens, completed, sum_logprobs, pre_logits
 
         try:
-            for i in range(self.sample_len):
-                logits = self.inference.logits(tokens, audio_features)
+            tokens, completed, sum_logprobs, pre_logits = _step(
+                tokens, audio_features, tokens, sum_logprobs
+            )
+            if self.tokenizer.no_speech is not None:  # compute no_speech_probs
+                probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
+                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
+            else:
+                no_speech_probs = mx.full(n_batch, mx.nan)
+            mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
 
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
-                    probs_at_sot = mx.softmax(
-                        logits[:, self.sot_index].astype(mx.float32), axis=-1
-                    )
-                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-
-                # now we need to consider the logits at the last token only
-                logits = logits[:, -1]
-
-                # apply the logit filters, e.g. for suppressing or applying penalty to
-                for logit_filter in self.logit_filters:
-                    logits = logit_filter.apply(logits, tokens)
-
-                # expand the tokens tensor with the selected next tokens
-                tokens, completed, sum_logprobs = self.decoder.update(
-                    tokens, logits, sum_logprobs
+            for i in range(1, self.sample_len):
+                inputs = tokens[:, -1:]
+                next_tokens, next_completed, next_sum_logprobs, _ = _step(
+                    inputs, audio_features, tokens, sum_logprobs
                 )
-
+                mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
+                tokens = next_tokens
+                completed = next_completed
+                sum_logprobs = next_sum_logprobs
+
         finally:
             self.inference.reset()
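The rewritten loop is software-pipelined: each iteration dispatches the next step with `mx.async_eval` before the host checks the previous step's `completed` flag, so Python overhead and the compute for step i+1 overlap with reading back step i. The general shape of the pattern, reduced to a toy step function (everything here is illustrative, not part of the patch):

    import mlx.core as mx

    def step(x):
        return x @ x  # stand-in for one decoder step

    x = mx.random.normal((128, 128))
    y = step(x)
    mx.async_eval(y)           # dispatch step 1 without blocking
    for _ in range(10):
        y_next = step(y)       # build the graph for step i+1
        mx.async_eval(y_next)  # dispatch it before syncing on step i
        # host-side checks on y here only wait for step i, not i+1
        y = y_next
    mx.eval(y)                 # final synchronization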
@@ -610,8 +623,8 @@
         n_audio: int = mel.shape[0]
 
         audio_features: mx.array = self._get_audio_features(mel)  # encoder forward pass
-        tokens: np.array = np.array(self.initial_tokens)
-        tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
+        tokens: mx.array = mx.array(self.initial_tokens)
+        tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))
 
         # detect language if requested, overwriting the language token
         languages, language_probs = self._detect_language(audio_features, tokens)
@@ -626,7 +639,6 @@
         ]
 
         # repeat tokens by the group size, for beam search or best-of-n sampling
-        tokens = mx.array(tokens)
         if self.n_group > 1:
             tokens = tokens[:, None, :]
             tokens = mx.broadcast_to(
@@ -649,7 +661,13 @@
 
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens = tokens[..., self.sample_begin :].tolist()
+        tokens = tokens[..., self.sample_begin :]
+
+        # eval and convert to list
+        mx.eval(tokens, sum_logprobs, no_speech_probs)
+        tokens = tokens.tolist()
+        sum_logprobs = sum_logprobs.tolist()
+        no_speech_probs = no_speech_probs.tolist()
         tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
 
         # select the top-ranked sample in each group
diff --git a/whisper/mlx_whisper/load_models.py b/whisper/mlx_whisper/load_models.py
index 6705385d..60766ab2 100644
--- a/whisper/mlx_whisper/load_models.py
+++ b/whisper/mlx_whisper/load_models.py
@@ -26,7 +26,10 @@ def load_model(
 
     model_args = whisper.ModelDimensions(**config)
 
-    weights = mx.load(str(model_path / "weights.npz"))
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))
 
     model = whisper.Whisper(model_args, dtype)
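This loader pairs with the `mx.save_safetensors` call in convert.py: `mx.load` infers the format from the file extension, so the same call reads either the new safetensors file or a legacy npz. A round-trip sketch under the same probe order (the directory name and weight dict are dummies):

    import mlx.core as mx
    from pathlib import Path

    model_path = Path("mlx_model")  # illustrative path
    model_path.mkdir(exist_ok=True)
    mx.save_safetensors(
        str(model_path / "weights.safetensors"), {"encoder.w": mx.zeros((4, 4))}
    )

    # prefer safetensors, fall back to npz, as in load_model
    wf = model_path / "weights.safetensors"
    if not wf.exists():
        wf = model_path / "weights.npz"
    weights = mx.load(str(wf))  # dict of arrays either way
    assert weights["encoder.w"].shape == (4, 4)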
diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py
index 786b4232..7057679b 100644
--- a/whisper/mlx_whisper/transcribe.py
+++ b/whisper/mlx_whisper/transcribe.py
@@ -293,6 +293,7 @@ def transcribe(
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
 
             result: DecodingResult = decode_with_fallback(mel_segment)
+
             tokens = np.array(result.tokens)
 
             if no_speech_threshold is not None:
diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py
index e691792c..1c2b390e 100644
--- a/whisper/mlx_whisper/whisper.py
+++ b/whisper/mlx_whisper/whisper.py
@@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.astype(mx.float32)
-        w = mx.softmax(qk, axis=-1).astype(q.dtype)
+        w = mx.softmax(qk, axis=-1, precise=True)
 
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk
+        return out, qk.astype(mx.float32)
 
 
 class ResidualAttentionBlock(nn.Module):
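`mx.softmax(..., precise=True)` accumulates in higher precision internally while keeping the input dtype on the output, which is what the removed upcast/downcast pair was emulating by hand; it also avoids materializing a float32 copy of `qk` on every call, deferring that cast to the returned attention scores. A before/after sketch on dummy half-precision scores (shapes illustrative):

    import mlx.core as mx

    qk = mx.random.normal((1, 8, 64, 64)).astype(mx.float16)

    # old path: upcast, softmax in float32, downcast
    w_old = mx.softmax(qk.astype(mx.float32), axis=-1).astype(mx.float16)

    # new path: one call, higher-precision accumulation inside
    w_new = mx.softmax(qk, axis=-1, precise=True)

    print(mx.abs(w_old - w_new).max())  # differences stay at rounding level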