whisper default in fp16
Commit 6e723a015a (parent 13f1142eaa) in https://github.com/ml-explore/mlx-examples.git

This commit makes float16 the default for the MLX Whisper example: weights are converted and loaded in fp16, DecodingOptions.fp16 now defaults to True, and LayerNorm is computed in float32 for numerical stability.
benchmark.py: feature extraction moves out of the per-model loop, and the mel input is computed once and cast to fp16:

```diff
@@ -57,12 +57,13 @@ if __name__ == "__main__":
     if sys.argv[1] == "--all":
         models = ["tiny", "small", "medium", "large"]
 
+        feat_time = timer(feats)
+        print(f"\nFeature time {feat_time:.3f}")
+        mels = feats()[None].astype(mx.float16)
+
         for model_name in models:
-            feat_time = timer(feats)
 
             print(f"\nModel: {model_name.upper()}")
-            print(f"\nFeature time {feat_time:.3f}")
-
-            mels = feats()[None]
             tokens = mx.array(
                 [
                     50364,
```
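The `timer` helper itself is not part of this diff. A minimal sketch of what such a helper typically looks like in MLX (the warm-up and iteration counts are assumptions, not the repository's code):

```python
import time

import mlx.core as mx


def timer(fn, *args):
    # Warm-up runs: MLX is lazy, so mx.eval forces the computation
    # to actually execute before and while we time it.
    for _ in range(3):
        mx.eval(fn(*args))
    tic = time.perf_counter()
    iterations = 10
    for _ in range(iterations):
        mx.eval(fn(*args))
    return (time.perf_counter() - tic) / iterations
```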
Each model in the benchmark is then loaded in fp16 explicitly:

```diff
@@ -96,7 +97,7 @@ if __name__ == "__main__":
                 ],
                 mx.int32,
             )[None]
-            model = load_models.load_model(f"{model_name}")
+            model = load_models.load_model(f"{model_name}", dtype=mx.float16)
             model_forward_time = timer(model_forward, model, mels, tokens)
             print(f"Model forward time {model_forward_time:.3f}")
             decode_time = timer(decode, model, mels)
```
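A quick way to confirm what this loading path produces: every parameter leaf should come back as float16. This is a sketch, assuming the `whisper` package layout used by the example's tests:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

from whisper import load_models

model = load_models.load_model("tiny", dtype=mx.float16)
# tree_flatten yields (name, array) pairs for every leaf in the tree.
assert all(p.dtype == mx.float16 for _, p in tree_flatten(model.parameters()))
```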
In the test suite (test.py), the shared reference model is pinned to float32 so the existing expected outputs remain valid:

```diff
@@ -36,7 +36,7 @@ def forward_mlx(model, mels, tokens):
 class TestWhisper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = load_models.load_model("tiny")
+        cls.model = load_models.load_model("tiny", dtype=mx.float32)
         data = audio.load_audio(TEST_AUDIO)
         data = audio.pad_or_trim(data)
         cls.mels = audio.log_mel_spectrogram(data)
```
A new test_fp16 case checks that fp16 propagates end to end (random inputs suffice, since only the dtype is asserted), while the decoding tests opt out of the new default:

```diff
@@ -52,13 +52,22 @@ class TestWhisper(unittest.TestCase):
 
         torch_logits = forward_torch(torch_model, mels, tokens)
 
-        mlx_model = load_models.torch_to_mlx(torch_model)
+        mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
         mlx_logits = forward_mlx(mlx_model, mels, tokens)
 
         self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
 
+    def test_fp16(self):
+        mlx_model = load_models.load_model("tiny", dtype=mx.float16)
+        dims = mlx_model.dims
+        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
+        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
+        logits = mlx_model(mels, tokens)
+        self.assertEqual(logits.dtype, mx.float16)
+
+
     def test_decode_lang(self):
-        options = decoding.DecodingOptions(task="lang_id")
+        options = decoding.DecodingOptions(task="lang_id", fp16=False)
         result = decoding.decode(self.model, self.mels, options)
         self.assertEqual(result.language, "en")
         self.assertEqual(len(result.language_probs), 99)
```
```diff
@@ -67,7 +76,7 @@ class TestWhisper(unittest.TestCase):
         )
 
     def test_decode_greedy(self):
-        result = decoding.decode(self.model, self.mels)
+        result = decoding.decode(self.model, self.mels, fp16=False)
         self.assertEqual(result.language, "en")
         self.assertEqual(
             result.tokens,
```
```diff
@@ -114,7 +123,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
         # Small temp should give the same results
-        result = decoding.decode(self.model, self.mels, temperature=1e-8)
+        result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)
 
         self.assertEqual(
             result.text,
```
```diff
@@ -128,7 +137,7 @@ class TestWhisper(unittest.TestCase):
         self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
 
     def test_transcribe(self):
-        result = whisper.transcribe(TEST_AUDIO)
+        result = whisper.transcribe(TEST_AUDIO, fp16=False)
         self.assertEqual(
             result["text"],
             (
```
```diff
@@ -147,7 +156,7 @@ class TestWhisper(unittest.TestCase):
             print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
             return
 
-        result = whisper.transcribe(audio_file)
+        result = whisper.transcribe(audio_file, fp16=False)
         self.assertEqual(len(result["text"]), 10920)
         self.assertEqual(result["language"], "en")
         self.assertEqual(len(result["segments"]), 77)
```
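Every reference test now passes fp16=False explicitly because the default flipped. The two paths look like this (a usage sketch, not code from the commit):

```python
from whisper import decoding

# fp16 now defaults to True, so float32 reference runs must opt out.
opts_fast = decoding.DecodingOptions()             # fp16=True, new default
opts_exact = decoding.DecodingOptions(fp16=False)  # previous fp32 behavior
```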
In the decoding module (decoding.py), the fp16 default flips to True:

```diff
@@ -110,7 +110,7 @@ class DecodingOptions:
     max_initial_timestamp: Optional[float] = 1.0
 
     # implementation details
-    fp16: bool = False  # use fp16 for most of the calculation
+    fp16: bool = True  # use fp16 for most of the calculation
 
 
 @dataclass(frozen=True)
```
```diff
@@ -141,7 +141,7 @@ class Inference:
         logits, self.kv_cache = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
-        return logits
+        return logits.astype(mx.float32)
 
     def rearrange_kv_cache(self, source_indices):
         """Update the key-value cache according to the updated beams"""
```
```diff
@@ -542,7 +542,7 @@ class DecodingTask:
         audio_features = self.model.encoder(mel)
 
         if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32):
-            return TypeError(
+            raise TypeError(
                 f"audio_features has an incorrect dtype: {audio_features.dtype}"
             )
 
```

This hunk also fixes a latent bug: the dtype check previously returned the TypeError object instead of raising it, so the error was silently discarded.
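Upcasting the decoder output in Inference keeps the sampling and log-probability math in float32 even when the model itself runs in fp16. A sketch of the downstream computation this protects (the vocabulary size here is an assumption):

```python
import mlx.core as mx

# fp16 logits over a large vocabulary: do the log-softmax reduction in
# float32 to avoid fp16 precision loss, which is what the upcast enables.
logits = mx.random.normal((1, 51864)).astype(mx.float16)
logits32 = logits.astype(mx.float32)
logprobs = logits32 - mx.logsumexp(logits32, axis=-1, keepdims=True)
```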
The converter (load_models.py) imports tree_map and threads a dtype through conversion and loading:

```diff
@@ -7,6 +7,7 @@ import warnings
 from typing import List
 
 import mlx.core as mx
+from mlx.utils import tree_map
 
 import torch
 from tqdm import tqdm
```
```diff
@@ -163,7 +164,7 @@ def convert(model, rules=None):
 
 
 def torch_to_mlx(
-    torch_model: torch_whisper.Whisper,
+    torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
 ) -> whisper.Whisper:
     def convert_rblock(model, rules):
         children = dict(model.named_children())
```
```diff
@@ -182,7 +183,8 @@ def torch_to_mlx(
 
     params = convert(torch_model, rules)
 
-    mlx_model = whisper.Whisper(torch_model.dims)
+    mlx_model = whisper.Whisper(torch_model.dims, dtype)
+    params = tree_map(lambda p: p.astype(dtype), params)
     mlx_model.update(params)
     return mlx_model
 
```
```diff
@@ -190,5 +192,6 @@ def torch_to_mlx(
 def load_model(
     name: str,
     download_root: str = None,
+    dtype: mx.Dtype = mx.float32,
 ) -> whisper.Whisper:
-    return torch_to_mlx(load_torch_model(name, download_root))
+    return torch_to_mlx(load_torch_model(name, download_root), dtype)
```
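The whole parameter tree is cast with a single tree_map call. Note that the two entry points pick different defaults: torch_to_mlx defaults to float16 while load_model defaults to float32. A standalone sketch of the casting pattern, using toy parameters rather than a real checkpoint:

```python
import mlx.core as mx
from mlx.utils import tree_map

# tree_map applies a function to every array leaf of a nested structure,
# so a single call casts an entire parameter dict.
params = {"encoder": {"weight": mx.zeros((4, 4))}, "bias": mx.zeros((4,))}
params = tree_map(lambda p: p.astype(mx.float16), params)
assert params["encoder"]["weight"].dtype == mx.float16
```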
In transcribe.py, the cached ModelHolder takes the dtype, and transcribe derives it from decode_options with fp16 defaulting to True:

```diff
@@ -43,9 +43,9 @@ class ModelHolder:
     model_name = None
 
     @classmethod
-    def get_model(cls, model: str):
+    def get_model(cls, model: str, dtype: mx.Dtype):
         if cls.model is None or model != cls.model_name:
-            cls.model = load_model(model)
+            cls.model = load_model(model, dtype=dtype)
             cls.model_name = model
         return cls.model
 
```
```diff
@@ -114,9 +114,8 @@ def transcribe(
         the spoken language ("language"), which is detected when `decode_options["language"]` is None.
     """
 
-    model = ModelHolder.get_model(model)
-    dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32
-
+    dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
+    model = ModelHolder.get_model(model, dtype)
 
     # Pad 30-seconds of silence to the input audio, for slicing
     mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
```
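Callers now get fp16 by default and opt out per call. One caveat visible in the hunk above: ModelHolder's cache key is only the model name, so alternating the fp16 setting across calls that use the same model name will reuse the already-loaded model rather than reloading it in the other dtype. A usage sketch (the audio path is an assumption):

```python
import whisper

result = whisper.transcribe("speech.wav")                   # fp16, new default
result_fp32 = whisper.transcribe("speech.wav", fp16=False)  # old fp32 path
print(result["text"], result["language"])
```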
Finally, the model definition (whisper.py) gains an fp32-stable LayerNorm and a dtype argument on every module:

```diff
@@ -37,6 +37,10 @@ def sinusoids(length, channels, max_timescale=10000):
     scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
     return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
 
+class LayerNorm(nn.LayerNorm):
+    def __call__(self, x: mx.array) -> mx.array:
+        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
+
 
 class MultiHeadAttention(nn.Module):
     def __init__(self, n_state: int, n_head: int):
```
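This is the standard mixed-precision trick: LayerNorm's mean and variance reductions are sensitive to fp16 rounding, so the statistics are computed in float32 and the result is cast back to the input dtype. A self-contained demonstration with toy sizes:

```python
import mlx.core as mx
import mlx.nn as nn


class LayerNorm(nn.LayerNorm):
    # The pattern the commit adds: upcast, normalize, cast back.
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)


x = mx.random.normal((2, 16)).astype(mx.float16)
y = LayerNorm(16)(x)
assert y.dtype == mx.float16  # output dtype follows the input
```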
```diff
@@ -94,17 +98,17 @@ class ResidualAttentionBlock(nn.Module):
         super().__init__()
 
         self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = nn.LayerNorm(n_state)
+        self.attn_ln = LayerNorm(n_state)
 
         self.cross_attn = (
             MultiHeadAttention(n_state, n_head) if cross_attention else None
         )
-        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
         self.mlp1 = nn.Linear(n_state, n_mlp)
         self.mlp2 = nn.Linear(n_mlp, n_state)
-        self.mlp_ln = nn.LayerNorm(n_state)
+        self.mlp_ln = LayerNorm(n_state)
 
     def __call__(self, x, xa=None, mask=None, kv_cache=None):
         kv, cross_kv = kv_cache if kv_cache else (None, None)
```
```diff
@@ -119,15 +123,15 @@ class ResidualAttentionBlock(nn.Module):
 
 class AudioEncoder(nn.Module):
     def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
     ):
         super().__init__()
         self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
         self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
-        self._positional_embedding = sinusoids(n_ctx, n_state)
+        self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
 
         self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-        self.ln_post = nn.LayerNorm(n_state)
+        self.ln_post = LayerNorm(n_state)
 
     def __call__(self, x):
         x = nn.gelu(self.conv1(x))
```
```diff
@@ -144,7 +148,7 @@ class AudioEncoder(nn.Module):
 
 class TextDecoder(nn.Module):
     def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
     ):
         super().__init__()
 
```
```diff
@@ -155,8 +159,8 @@ class TextDecoder(nn.Module):
             ResidualAttentionBlock(n_state, n_head, cross_attention=True)
             for _ in range(n_layer)
         ]
-        self.ln = nn.LayerNorm(n_state)
-        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx)
+        self.ln = LayerNorm(n_state)
+        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
 
     def __call__(self, x, xa, kv_cache=None):
         """
```
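Buffers that are not learned weights, namely the encoder's sinusoidal positional embedding and the decoder's causal mask, are cast explicitly at construction: their leading-underscore names keep them out of the converted parameter tree, so the tree_map cast in load_models never touches them. The mask is additive (large negative values above the diagonal), so it must match the dtype of the fp16 attention scores it is added to. A small sketch:

```python
import mlx.core as mx
import mlx.nn as nn

# Additive causal mask for a 4-token context, pre-cast once at init
# instead of on every forward pass.
mask = nn.MultiHeadAttention.create_additive_causal_mask(4).astype(mx.float16)
```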
```diff
@@ -181,7 +185,7 @@ class TextDecoder(nn.Module):
 
 
 class Whisper(nn.Module):
-    def __init__(self, dims: ModelDimensions):
+    def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16):
         super().__init__()
         self.dims = dims
         self.encoder = AudioEncoder(
```
```diff
@@ -190,6 +194,7 @@ class Whisper(nn.Module):
             self.dims.n_audio_state,
             self.dims.n_audio_head,
             self.dims.n_audio_layer,
+            dtype,
         )
         self.decoder = TextDecoder(
             self.dims.n_vocab,
```
```diff
@@ -197,6 +202,7 @@ class Whisper(nn.Module):
             self.dims.n_text_state,
             self.dims.n_text_head,
             self.dims.n_text_layer,
+            dtype,
         )
 
     def embed_audio(self, mel):
```
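Taken together, a single dtype argument now flows top-down through the whole example. A summary of the call chain (annotation, not code from the commit):

```python
# transcribe(audio, fp16=True)                        # fp16 is the new default
#   dtype = mx.float16
#   ModelHolder.get_model(model, dtype)
#     load_model(name, dtype=dtype)
#       torch_to_mlx(torch_model, dtype)
#         whisper.Whisper(dims, dtype)                # casts mask + pos. embedding
#         tree_map(lambda p: p.astype(dtype), params) # casts the weights
# LayerNorm still computes in float32; Inference returns logits as float32.
```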