whisper default in fp16

Awni Hannun
2023-12-12 07:37:35 -08:00
parent 13f1142eaa
commit 6e723a015a
6 changed files with 50 additions and 32 deletions

View File

@@ -110,7 +110,7 @@ class DecodingOptions:
max_initial_timestamp: Optional[float] = 1.0
# implementation details
-fp16: bool = False # use fp16 for most of the calculation
+fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)
@@ -141,7 +141,7 @@ class Inference:
logits, self.kv_cache = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
-return logits
+return logits.astype(mx.float32)
def rearrange_kv_cache(self, source_indices):
"""Update the key-value cache according to the updated beams"""
@@ -542,7 +542,7 @@ class DecodingTask:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32):
-return TypeError(
+raise TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
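
Note on the decoding hunks above: decoding now defaults to fp16, but the decoder output is cast back to float32 before probabilities are computed. A minimal sketch of that cast pattern; the batch and vocabulary sizes are illustrative, not taken from the diff:

    import mlx.core as mx

    # Hypothetical fp16 logits as the decoder would return them; the shape
    # (batch=1, vocab=51865) is illustrative only.
    logits = mx.random.normal((1, 51865)).astype(mx.float16)

    # Upcast before softmax so the probabilities accumulate at full precision,
    # mirroring `return logits.astype(mx.float32)` above.
    probs = mx.softmax(logits.astype(mx.float32), axis=-1)
    print(probs.dtype)  # float32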

View File

@@ -7,6 +7,7 @@ import warnings
from typing import List
import mlx.core as mx
+from mlx.utils import tree_map
import torch
from tqdm import tqdm
@@ -163,7 +164,7 @@ def convert(model, rules=None):
def torch_to_mlx(
-torch_model: torch_whisper.Whisper,
+torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
@@ -182,7 +183,8 @@ def torch_to_mlx(
params = convert(torch_model, rules)
-mlx_model = whisper.Whisper(torch_model.dims)
+mlx_model = whisper.Whisper(torch_model.dims, dtype)
+params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
return mlx_model
@@ -190,5 +192,6 @@ def torch_to_mlx(
def load_model(
name: str,
download_root: str = None,
+dtype : mx.Dtype = mx.float32,
) -> whisper.Whisper:
-return torch_to_mlx(load_torch_model(name, download_root))
+return torch_to_mlx(load_torch_model(name, download_root), dtype)
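
With the conversion changes above, torch_to_mlx casts the converted weights to the requested dtype via tree_map, and load_model forwards a dtype through. A usage sketch; the import path and the "tiny" model name are assumptions for illustration:

    import mlx.core as mx

    # Assumed import path; load_model is defined in the file patched above.
    from load_models import load_model

    model_fp16 = load_model("tiny", dtype=mx.float16)  # weights cast to fp16 after conversion
    model_fp32 = load_model("tiny")                     # load_model itself still defaults to float32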

View File

@@ -43,9 +43,9 @@ class ModelHolder:
model_name = None
@classmethod
-def get_model(cls, model: str):
+def get_model(cls, model: str, dtype : mx.Dtype):
if cls.model is None or model != cls.model_name:
-cls.model = load_model(model)
+cls.model = load_model(model, dtype=dtype)
cls.model_name = model
return cls.model
@@ -114,9 +114,8 @@ def transcribe(
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
-model = ModelHolder.get_model(model)
-dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32
+dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
+model = ModelHolder.get_model(model, dtype)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
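
In transcribe, the compute dtype now follows decode_options["fp16"], which defaults to True, and is passed to the cached model loader. A hedged usage sketch; the audio path, the model name, and the returned keys are illustrative assumptions:

    # transcribe is the function patched above; paths and names are illustrative.
    result = transcribe("speech.wav", model="tiny")                   # runs in fp16 by default
    result_fp32 = transcribe("speech.wav", model="tiny", fp16=False)  # opt back into fp32
    print(result["text"])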

View File

@@ -37,6 +37,10 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
+class LayerNorm(nn.LayerNorm):
+    def __call__(self, x: mx.array) -> mx.array:
+        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
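
The LayerNorm subclass added above keeps activations in fp16 outside the layer but performs the normalization itself in float32 and casts back. A minimal sketch of that pattern in isolation; the feature size and input shape are illustrative:

    import mlx.core as mx
    import mlx.nn as nn

    ln = nn.LayerNorm(8)
    x = mx.random.normal((2, 8)).astype(mx.float16)

    # Upcast for the reduction, then return to the input dtype, as the
    # subclass's __call__ does.
    y = ln(x.astype(mx.float32)).astype(x.dtype)
    print(y.dtype)  # float16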
@@ -94,17 +98,17 @@ class ResidualAttentionBlock(nn.Module):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
-self.attn_ln = nn.LayerNorm(n_state)
+self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
-self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
+self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp1 = nn.Linear(n_state, n_mlp)
self.mlp2 = nn.Linear(n_mlp, n_state)
-self.mlp_ln = nn.LayerNorm(n_state)
+self.mlp_ln = LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
@@ -119,15 +123,15 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(
-self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
-self._positional_embedding = sinusoids(n_ctx, n_state)
+self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-self.ln_post = nn.LayerNorm(n_state)
+self.ln_post = LayerNorm(n_state)
def __call__(self, x):
x = nn.gelu(self.conv1(x))
@@ -144,7 +148,7 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(
-self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
):
super().__init__()
@@ -155,8 +159,8 @@ class TextDecoder(nn.Module):
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
-self.ln = nn.LayerNorm(n_state)
-self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx)
+self.ln = LayerNorm(n_state)
+self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
def __call__(self, x, xa, kv_cache=None):
"""
@@ -181,7 +185,7 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module):
-def __init__(self, dims: ModelDimensions):
+def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
@@ -190,6 +194,7 @@ class Whisper(nn.Module):
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
+dtype,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
@@ -197,6 +202,7 @@ class Whisper(nn.Module):
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
+dtype,
)
def embed_audio(self, mel):
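
Taken together, Whisper, AudioEncoder, and TextDecoder now accept a dtype (default mx.float16) that is applied to the sinusoidal positional embedding and the additive causal mask, while the layer norms keep their reductions in float32. A hedged construction sketch; the ModelDimensions values below are a tiny-sized configuration chosen for illustration, not taken from the diff:

    import mlx.core as mx

    # ModelDimensions and Whisper are the classes patched above; the sizes are
    # illustrative assumptions.
    dims = ModelDimensions(
        n_mels=80, n_audio_ctx=1500, n_audio_state=384,
        n_audio_head=6, n_audio_layer=4,
        n_vocab=51865, n_text_ctx=448, n_text_state=384,
        n_text_head=6, n_text_layer=4,
    )
    model = Whisper(dims)                   # fp16 mask and positional embedding by default
    model_fp32 = Whisper(dims, mx.float32)  # opt back into full precision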