From 6e723a015a5b9d8f39d64a8e2788d87709c45123 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 12 Dec 2023 07:37:35 -0800
Subject: [PATCH] whisper default in fp16

---
 whisper/benchmark.py           |  9 +++++----
 whisper/test.py                | 23 ++++++++++++++++-------
 whisper/whisper/decoding.py    |  6 +++---
 whisper/whisper/load_models.py |  9 ++++++---
 whisper/whisper/transcribe.py  |  9 ++++-----
 whisper/whisper/whisper.py     | 26 ++++++++++++++++----------
 6 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/whisper/benchmark.py b/whisper/benchmark.py
index 9df6b500..228a3b36 100644
--- a/whisper/benchmark.py
+++ b/whisper/benchmark.py
@@ -57,12 +57,13 @@ if __name__ == "__main__":
     if sys.argv[1] == "--all":
         models = ["tiny", "small", "medium", "large"]
 
+    feat_time = timer(feats)
+    print(f"\nFeature time {feat_time:.3f}")
+    mels = feats()[None].astype(mx.float16)
+
     for model_name in models:
-        feat_time = timer(feats)
         print(f"\nModel: {model_name.upper()}")
-        print(f"\nFeature time {feat_time:.3f}")
-        mels = feats()[None]
 
         tokens = mx.array(
             [
                 50364,
@@ -96,7 +97,7 @@ if __name__ == "__main__":
             ],
             mx.int32,
         )[None]
-        model = load_models.load_model(f"{model_name}")
+        model = load_models.load_model(f"{model_name}", dtype=mx.float16)
         model_forward_time = timer(model_forward, model, mels, tokens)
         print(f"Model forward time {model_forward_time:.3f}")
         decode_time = timer(decode, model, mels)
diff --git a/whisper/test.py b/whisper/test.py
index 44f99edf..79f233ba 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -36,7 +36,7 @@ def forward_mlx(model, mels, tokens):
 class TestWhisper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = load_models.load_model("tiny")
+        cls.model = load_models.load_model("tiny", dtype=mx.float32)
         data = audio.load_audio(TEST_AUDIO)
         data = audio.pad_or_trim(data)
         cls.mels = audio.log_mel_spectrogram(data)
@@ -52,13 +52,22 @@ class TestWhisper(unittest.TestCase):
 
         torch_logits = forward_torch(torch_model, mels, tokens)
 
-        mlx_model = load_models.torch_to_mlx(torch_model)
+        mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
         mlx_logits = forward_mlx(mlx_model, mels, tokens)
 
         self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
 
+    def test_fp16(self):
+        mlx_model = load_models.load_model("tiny", dtype=mx.float16)
+        dims = mlx_model.dims
+        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
+        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
+        logits = mlx_model(mels, tokens)
+        self.assertEqual(logits.dtype, mx.float16)
+
+
     def test_decode_lang(self):
-        options = decoding.DecodingOptions(task="lang_id")
+        options = decoding.DecodingOptions(task="lang_id", fp16=False)
         result = decoding.decode(self.model, self.mels, options)
         self.assertEqual(result.language, "en")
         self.assertEqual(len(result.language_probs), 99)
@@ -67,7 +76,7 @@ class TestWhisper(unittest.TestCase):
         )
 
     def test_decode_greedy(self):
-        result = decoding.decode(self.model, self.mels)
+        result = decoding.decode(self.model, self.mels, fp16=False)
         self.assertEqual(result.language, "en")
         self.assertEqual(
             result.tokens,
@@ -114,7 +123,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
         # Small temp should give the same results
-        result = decoding.decode(self.model, self.mels, temperature=1e-8)
+        result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)
         self.assertEqual(
             result.text,
             (
@@ -128,7 +137,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
     def test_transcribe(self):
-        result = whisper.transcribe(TEST_AUDIO)
+        result = whisper.transcribe(TEST_AUDIO, fp16=False)
         self.assertEqual(
             result["text"],
             (
@@ -147,7 +156,7 @@ class TestWhisper(unittest.TestCase):
             print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
             return
 
-        result = whisper.transcribe(audio_file)
+        result = whisper.transcribe(audio_file, fp16=False)
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")
         self.assertEqual(len(result["segments"]), 77)
diff --git a/whisper/whisper/decoding.py b/whisper/whisper/decoding.py
index c4b6326d..d63d5e98 100644
--- a/whisper/whisper/decoding.py
+++ b/whisper/whisper/decoding.py
@@ -110,7 +110,7 @@ class DecodingOptions:
     max_initial_timestamp: Optional[float] = 1.0
 
     # implementation details
-    fp16: bool = False  # use fp16 for most of the calculation
+    fp16: bool = True  # use fp16 for most of the calculation
 
 
 @dataclass(frozen=True)
@@ -141,7 +141,7 @@ class Inference:
         logits, self.kv_cache = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
-        return logits
+        return logits.astype(mx.float32)
 
     def rearrange_kv_cache(self, source_indices):
         """Update the key-value cache according to the updated beams"""
@@ -542,7 +542,7 @@ class DecodingTask:
             audio_features = self.model.encoder(mel)
 
         if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32):
-            return TypeError(
+            raise TypeError(
                 f"audio_features has an incorrect dtype: {audio_features.dtype}"
             )
 
diff --git a/whisper/whisper/load_models.py b/whisper/whisper/load_models.py
index 2d0ae578..6a4e301b 100644
--- a/whisper/whisper/load_models.py
+++ b/whisper/whisper/load_models.py
@@ -7,6 +7,7 @@ import warnings
 from typing import List
 
 import mlx.core as mx
+from mlx.utils import tree_map
 import torch
 from tqdm import tqdm
 
@@ -163,7 +164,7 @@ def convert(model, rules=None):
 
 
 def torch_to_mlx(
-    torch_model: torch_whisper.Whisper,
+    torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
 ) -> whisper.Whisper:
     def convert_rblock(model, rules):
         children = dict(model.named_children())
@@ -182,7 +183,8 @@ def torch_to_mlx(
 
     params = convert(torch_model, rules)
 
-    mlx_model = whisper.Whisper(torch_model.dims)
+    mlx_model = whisper.Whisper(torch_model.dims, dtype)
+    params = tree_map(lambda p: p.astype(dtype), params)
     mlx_model.update(params)
 
     return mlx_model
@@ -190,5 +192,6 @@
 def load_model(
     name: str,
     download_root: str = None,
+    dtype : mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    return torch_to_mlx(load_torch_model(name, download_root))
+    return torch_to_mlx(load_torch_model(name, download_root), dtype)
diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py
index bfdc32b5..f05b828c 100644
--- a/whisper/whisper/transcribe.py
+++ b/whisper/whisper/transcribe.py
@@ -43,9 +43,9 @@ class ModelHolder:
     model_name = None
 
     @classmethod
-    def get_model(cls, model: str):
+    def get_model(cls, model: str, dtype : mx.Dtype):
         if cls.model is None or model != cls.model_name:
-            cls.model = load_model(model)
+            cls.model = load_model(model, dtype=dtype)
             cls.model_name = model
         return cls.model
 
@@ -114,9 +114,8 @@ def transcribe(
        the spoken language ("language"), which is detected when `decode_options["language"]` is None.
     """
 
-    model = ModelHolder.get_model(model)
-
-    dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32
+    dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
+    model = ModelHolder.get_model(model, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py
index ec60c6ec..1c7b856f 100644
--- a/whisper/whisper/whisper.py
+++ b/whisper/whisper/whisper.py
@@ -37,6 +37,10 @@ def sinusoids(length, channels, max_timescale=10000):
     scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
     return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
 
+class LayerNorm(nn.LayerNorm):
+    def __call__(self, x: mx.array) -> mx.array:
+        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
+
 
 class MultiHeadAttention(nn.Module):
     def __init__(self, n_state: int, n_head: int):
@@ -94,17 +98,17 @@ class ResidualAttentionBlock(nn.Module):
         super().__init__()
 
         self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = nn.LayerNorm(n_state)
+        self.attn_ln = LayerNorm(n_state)
 
         self.cross_attn = (
             MultiHeadAttention(n_state, n_head) if cross_attention else None
         )
-        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
         self.mlp1 = nn.Linear(n_state, n_mlp)
         self.mlp2 = nn.Linear(n_mlp, n_state)
-        self.mlp_ln = nn.LayerNorm(n_state)
+        self.mlp_ln = LayerNorm(n_state)
 
     def __call__(self, x, xa=None, mask=None, kv_cache=None):
         kv, cross_kv = kv_cache if kv_cache else (None, None)
@@ -119,15 +123,15 @@
 
 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
     ):
         super().__init__()
         self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
-        self._positional_embedding = sinusoids(n_ctx, n_state)
+        self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
 
         self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-        self.ln_post = nn.LayerNorm(n_state)
+        self.ln_post = LayerNorm(n_state)
 
     def __call__(self, x):
         x = nn.gelu(self.conv1(x))
@@ -144,7 +148,7 @@ class AudioEncoder(nn.Module):
 
 class TextDecoder(nn.Module):
     def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
     ):
         super().__init__()
 
@@ -155,8 +159,8 @@ class TextDecoder(nn.Module):
             ResidualAttentionBlock(n_state, n_head, cross_attention=True)
             for _ in range(n_layer)
         ]
-        self.ln = nn.LayerNorm(n_state)
-        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx)
+        self.ln = LayerNorm(n_state)
+        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
 
     def __call__(self, x, xa, kv_cache=None):
         """
@@ -181,7 +185,7 @@ class TextDecoder(nn.Module):
 
 
 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
+    def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
@@ -190,6 +194,7 @@ class Whisper(nn.Module):
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
+            dtype,
         )
         self.decoder = TextDecoder(
             self.dims.n_vocab,
@@ -197,6 +202,7 @@ class Whisper(nn.Module):
             self.dims.n_text_state,
             self.dims.n_text_head,
             self.dims.n_text_layer,
+            dtype,
        )

    def embed_audio(self, mel):
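
Usage note (illustrative, not part of the patch): a minimal sketch of the behavior after this change, assuming the import layout exercised by whisper/test.py above (load_models.load_model, whisper.transcribe); the audio path is a placeholder.

    import mlx.core as mx

    import whisper  # the MLX whisper package patched above
    from whisper import load_models

    # load_model() now accepts a dtype; torch_to_mlx() casts the converted
    # weights with tree_map, so fp16 halves the memory of the loaded weights.
    model_fp16 = load_models.load_model("tiny", dtype=mx.float16)
    model_fp32 = load_models.load_model("tiny", dtype=mx.float32)  # tests keep fp32

    # DecodingOptions.fp16 and transcribe() now default to fp16;
    # pass fp16=False to keep the whole pipeline in fp32.
    result = whisper.transcribe("audio.wav", fp16=False)  # "audio.wav" is a placeholder path
    print(result["text"])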