Merge branch 'ml-explore:main' into main

This commit is contained in:
Pawel Kowalski 2023-12-13 12:18:12 +01:00 committed by GitHub
commit 47b0685c79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 307 additions and 188 deletions

View File

@ -9,9 +9,12 @@ Some more useful examples include:
- [Transformer language model](transformer_lm) training.
- Large scale text generation with [LLaMA](llama) or [Mistral](mistral).
- Mixture-of-experts (MoE) language model with [Mixtral 8x7B](mixtral).
- Parameter efficient fine-tuning with [LoRA](lora).
- Generating images with [Stable Diffusion](stable_diffusion).
- Speech recognition with [OpenAI's Whisper](whisper).
- Bidirectional language understanding with [BERT](bert).
- Semi-supervised learning on graph-structured data with [GCN](gcn).
## Contributing

View File

@ -25,7 +25,7 @@ def run(bert_model: str):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the BERT model using HuggingFace Transformers."
description="Run the BERT model using Hugging Face Transformers."
)
parser.add_argument(
"--bert-model",

View File

@ -1,8 +1,9 @@
# LLaMA
# Llama
An example of generating text with LLaMA using MLX.
An example of generating text with Llama (1 or 2) using MLX.
LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.
Llama is a set of open source language models from Meta AI Research[^1][^2]
ranging from 7B to 70B parameters.
### Setup
@ -13,28 +14,34 @@ pip install -r requirements.txt
```
Next, download and convert the model. If you do not have access to the model
weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.
weights you will need to request access from Meta:
- [Request Llama v1](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
- [Request Llama v2](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)
Alternatively, you can also download a select converted checkpoints from the [mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging Face and skip the conversion step.
Alternatively, you can also download select converted checkpoints from the
[mlx-llama](https://huggingface.co/mlx-llama) community organisation on Hugging
Face and skip the conversion step.
Convert the weights with:
```
python convert.py <path_to_torch_weights> <path_to_mlx_llama_weights.npz>
python convert.py --model_path <path_to_torch_model>
```
The conversion script will save the converted weights in the same location.
### Run
Once you've converted the weights to MLX format, you can interact with the
LLaMA model:
Llama model:
```
python llama.py <path_to_mlx_llama_weights.npz> <path_to_tokenizer.model> "hello"
python llama.py <path_to_model> <path_to_tokenizer.model> "hello"
```
Run `python llama.py --help` for more details.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^1]: For Llama v1 refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.
[^2]: For Llama v2 refer to the [blog post](https://ai.meta.com/llama/).

View File

@ -1,53 +1,59 @@
# Copyright © 2023 Apple Inc.
import argparse
from itertools import starmap
import collections
import glob
from pathlib import Path
import numpy as np
import torch
SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
def shard_key(k):
    """Return the second-to-last dotted component of a weight name.

    For a key such as ``"layers.0.attention.wq.weight"`` this is ``"wq"``.
    Keys with fewer than two dotted components yield ``None``.
    """
    # Only the last two components matter, so split from the right.
    parts = k.rsplit(".", 2)
    return parts[-2] if len(parts) >= 2 else None
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return (
key,
value.numpy()
if value.dtype != torch.bfloat16
else value.to(torch.float32).numpy(),
)
def unshard(k, v):
    """Merge the collected shards of weight ``k`` into a single array.

    ``v`` is returned untouched when the weight's shard key is not in
    ``SHARD_WEIGHTS``. Otherwise ``v`` is concatenated along axis 0 for names
    in ``SHARD_FIRST`` and along axis 1 for names in ``SHARD_SECOND``.

    Raises:
        ValueError: if the name is in ``SHARD_WEIGHTS`` but in neither list.
    """
    name = shard_key(k)
    if name not in SHARD_WEIGHTS:
        return v
    if name in SHARD_FIRST:
        return np.concatenate(v, axis=0)
    if name in SHARD_SECOND:
        return np.concatenate(v, axis=1)
    raise ValueError("Invalid weight name")
if __name__ == "__main__":
    # Convert every ``consolidated.*.pth`` shard found in --model_path into a
    # single ``weights.npz`` written to the same directory.
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
    parser.add_argument(
        "--model_path",
        required=True,  # Path(None) would otherwise fail with an opaque TypeError
        help="Path to the Torch model. The MLX weights will also be saved there.",
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    # glob returns matches in arbitrary order, but shards are concatenated
    # positionally below, so process them in shard-index order.
    torch_files = sorted(glob.glob(str(model_path / "consolidated.*.pth")))
    if not torch_files:
        raise FileNotFoundError(
            f"No consolidated.*.pth files found in {model_path}"
        )

    weights = collections.defaultdict(list)
    for wf in torch_files:
        state = torch.load(wf, map_location=torch.device("cpu"))
        for k, v in state.items():
            v = v.to(torch.float16).numpy()
            if shard_key(k) in SHARD_WEIGHTS:
                # Sharded weights accumulate one piece per checkpoint file.
                weights[k].append(v)
            else:
                # Other weights are stored as-is; a later file overwrites.
                weights[k] = v

    out_file = str(model_path / "weights.npz")
    for k, v in weights.items():
        weights[k] = unshard(k, v)
    np.savez(out_file, **weights)

View File

@ -1,8 +1,10 @@
# Copyright © 2023 Apple Inc.
import argparse
import math
import numpy as np
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
import time
@ -11,33 +13,71 @@ import mlx.nn as nn
from mlx.utils import tree_unflatten
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
@dataclass
class ModelArgs:
    """Architecture hyper-parameters used to construct the Llama model."""

    # Model width: embedding size and input size of the projections.
    dim: int
    # Number of transformer blocks.
    n_layers: int
    # Size of each attention head.
    head_dim: int
    # Inner width of the feed-forward layers.
    hidden_dim: int
    # Number of query heads.
    n_heads: int
    # Number of key/value heads.
    n_kv_heads: int
    # Epsilon passed to the RMSNorm layers.
    norm_eps: float
    # Number of rows of the token embedding / size of the output projection.
    vocab_size: int
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
self.num_heads = num_heads
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.n_heads: int = args.n_heads
self.n_kv_heads: int = args.n_kv_heads
self.repeats = self.n_heads // self.n_kv_heads
self.scale = self.args.head_dim**-0.5
self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(args.head_dim, traditional=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
keys, values = map(repeat, (keys, values))
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
@ -48,86 +88,87 @@ class LlamaAttention(nn.Module):
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.attention = Attention(args)
self.feed_forward = FeedForward(args=args)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.args = args
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out, cache
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
def __init__(self, args: ModelArgs):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
self.args = args
self.vocab_size = args.vocab_size
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
mask = mask.astype(self.tok_embeddings.weight.dtype)
x = self.embedding(x)
x = self.tok_embeddings(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
return self.output(x)
def generate(self, x, temp=1.0):
cache = []
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
mask = mask.astype(self.tok_embeddings.weight.dtype)
# First we process the prompt x the same way as in __call__ but
# save the caches in cache
x = self.embedding(x)
x = self.tok_embeddings(x)
for l in self.layers:
x, c = l(x, mask=mask)
# We store the per layer cache in a simple python list
cache.append(c)
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.out_proj(x[:, -1])
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
# y now has size [1]
@ -145,14 +186,14 @@ class Llama(nn.Module):
# dimension of 1
x = y[:, None]
x = self.embedding(x)
x = self.tok_embeddings(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When
# the computation will happen, MLX will be discarding the
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = self.output(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
yield y
@ -261,20 +302,33 @@ def few_shot_generate(args):
def load_model(model_path):
    """Build a Llama model from a converted checkpoint directory.

    Reads ``weights.npz`` and ``params.json`` from ``model_path``, fills in
    config fields that are missing, and returns the model with its
    parameters loaded and evaluated.
    """
    model_path = Path(model_path)
    weights = mx.load(str(model_path / "weights.npz"))
    with open(model_path / "params.json", "r") as f:
        config = json.load(f)

    # Derive fields that params.json may omit.
    n_heads = config["n_heads"]
    if "n_kv_heads" not in config:
        config["n_kv_heads"] = n_heads
    if "head_dim" not in config:
        config["head_dim"] = config["dim"] // n_heads
    if "hidden_dim" not in config:
        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
    # A missing or negative vocab size is recovered from the output projection.
    if config.get("vocab_size", -1) < 0:
        config["vocab_size"] = weights["output.weight"].shape[-1]

    # ModelArgs has no fields for these keys, so passing them through would
    # raise a TypeError in ModelArgs(**config).
    for k in ("multiple_of", "ffn_dim_multiplier"):
        config.pop(k, None)

    model = Llama(ModelArgs(**config))
    model.update(tree_unflatten(list(weights.items())))
    mx.eval(model.parameters())
    return model
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument("model", help="The model file containing MLX weights")
parser.add_argument(
"model", help="Path to the model directory containing the MLX weights"
)
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument("prompt", help="The message to be processed by the model")
parser.add_argument(

View File

@ -1,2 +1,3 @@
mlx
sentencepiece
torch

View File

@ -14,7 +14,7 @@ For example with Homebrew:
brew install git-lfs
```
Download the models from HuggingFace:
Download the models from Hugging Face:
```
git clone https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen
@ -46,7 +46,7 @@ rm mixtral-8x7b-32kseqlen/*.pth*
As easy as:
```
python mixtral.py --model_path mixtral mixtral-8x7b-32kseqlen/
python mixtral.py --model_path mixtral-8x7b-32kseqlen/
```
[^mixtral]: Refer to Mistral's [blog post](https://mistral.ai/news/mixtral-of-experts/) for more details.

View File

@ -16,7 +16,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
model_path = Path(args.model_path)
state = torch.load(str(model_path / "consolidated.00.pt"))
state = torch.load(str(model_path / "consolidated.00.pth"))
np.savez(
str(model_path / "weights.npz"),
**{k: v.to(torch.float16).numpy() for k, v in state.items()},

View File

@ -1,9 +1,9 @@
Stable Diffusion
================
Stable Diffusion in MLX. The implementation was ported from Huggingface's
Stable Diffusion in MLX. The implementation was ported from Hugging Face's
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
and using the weights available on the Huggingface Hub by Stability AI at
and using the weights available on the Hugging Face Hub by Stability AI at
[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
![out](generated-mlx.png)

View File

@ -193,7 +193,7 @@ def _check_key(key: str, part: str):
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion UNet from Huggingface Hub."""
"""Load the stable diffusion UNet from Hugging Face Hub."""
_check_key(key, "load_unet")
# Download the config and create the model
@ -223,7 +223,7 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion text encoder from Huggingface Hub."""
"""Load the stable diffusion text encoder from Hugging Face Hub."""
_check_key(key, "load_text_encoder")
# Download the config and create the model
@ -250,7 +250,7 @@ def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion autoencoder from Huggingface Hub."""
"""Load the stable diffusion autoencoder from Hugging Face Hub."""
_check_key(key, "load_autoencoder")
# Download the config and create the model
@ -279,7 +279,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
def load_diffusion_config(key: str = _DEFAULT_MODEL):
"""Load the stable diffusion config from Huggingface Hub."""
"""Load the stable diffusion config from Hugging Face Hub."""
_check_key(key, "load_diffusion_config")
diffusion_config = _MODELS[key]["diffusion_config"]

View File

@ -81,7 +81,7 @@ class Tokenizer:
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Huggingface does
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())

View File

@ -57,12 +57,13 @@ if __name__ == "__main__":
if sys.argv[1] == "--all":
models = ["tiny", "small", "medium", "large"]
feat_time = timer(feats)
print(f"\nFeature time {feat_time:.3f}")
mels = feats()[None].astype(mx.float16)
for model_name in models:
feat_time = timer(feats)
print(f"\nModel: {model_name.upper()}")
print(f"\nFeature time {feat_time:.3f}")
mels = feats()[None]
tokens = mx.array(
[
50364,
@ -96,7 +97,7 @@ if __name__ == "__main__":
],
mx.int32,
)[None]
model = load_models.load_model(f"{model_name}")
model = load_models.load_model(f"{model_name}", dtype=mx.float16)
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")
decode_time = timer(decode, model, mels)

View File

@ -36,7 +36,7 @@ def forward_mlx(model, mels, tokens):
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = load_models.load_model("tiny")
cls.model = load_models.load_model("tiny", dtype=mx.float32)
data = audio.load_audio(TEST_AUDIO)
data = audio.pad_or_trim(data)
cls.mels = audio.log_mel_spectrogram(data)
@ -52,13 +52,22 @@ class TestWhisper(unittest.TestCase):
torch_logits = forward_torch(torch_model, mels, tokens)
mlx_model = load_models.torch_to_mlx(torch_model)
mlx_model = load_models.torch_to_mlx(torch_model, mx.float32)
mlx_logits = forward_mlx(mlx_model, mels, tokens)
self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
def test_fp16(self):
    """A model loaded as float16 must return float16 logits."""
    model = load_models.load_model("tiny", dtype=mx.float16)
    n_mels, n_vocab = model.dims.n_mels, model.dims.n_vocab

    # Random features and tokens: only the output dtype is checked here.
    mels = mx.array(np.random.randn(1, 3_000, n_mels), mx.float16)
    tokens = mx.array(np.random.randint(0, n_vocab, (1, 20)), mx.int32)

    self.assertEqual(model(mels, tokens).dtype, mx.float16)
def test_decode_lang(self):
options = decoding.DecodingOptions(task="lang_id")
options = decoding.DecodingOptions(task="lang_id", fp16=False)
result = decoding.decode(self.model, self.mels, options)
self.assertEqual(result.language, "en")
self.assertEqual(len(result.language_probs), 99)
@ -67,7 +76,7 @@ class TestWhisper(unittest.TestCase):
)
def test_decode_greedy(self):
result = decoding.decode(self.model, self.mels)
result = decoding.decode(self.model, self.mels, fp16=False)
self.assertEqual(result.language, "en")
self.assertEqual(
result.tokens,
@ -114,7 +123,7 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
# Small temp should give the same results
result = decoding.decode(self.model, self.mels, temperature=1e-8)
result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)
self.assertEqual(
result.text,
@ -128,7 +137,7 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
result = whisper.transcribe(TEST_AUDIO)
result = whisper.transcribe(TEST_AUDIO, fp16=False)
self.assertEqual(
result["text"],
(
@ -147,7 +156,7 @@ class TestWhisper(unittest.TestCase):
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
result = whisper.transcribe(audio_file)
result = whisper.transcribe(audio_file, fp16=False)
self.assertEqual(len(result["text"]), 10920)
self.assertEqual(result["language"], "en")
self.assertEqual(len(result["segments"]), 77)

View File

@ -11,7 +11,6 @@ import numpy as np
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
@ -81,7 +80,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(n_mels: int = N_MELS) -> mx.array:
def mel_filters(n_mels: int) -> mx.array:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
@ -89,9 +88,10 @@ def mel_filters(n_mels: int = N_MELS) -> mx.array:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
return mx.load(filename)[f"mel_{n_mels}"]
@ -130,7 +130,7 @@ def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="re
def log_mel_spectrogram(
audio: Union[str, np.ndarray],
n_mels: int = N_MELS,
n_mels: int = 80,
padding: int = 0,
):
"""

View File

@ -33,7 +33,9 @@ def detect_language(
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
@ -110,7 +112,7 @@ class DecodingOptions:
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = False # use fp16 for most of the calculation
fp16: bool = True # use fp16 for most of the calculation
@dataclass(frozen=True)
@ -141,7 +143,7 @@ class Inference:
logits, self.kv_cache = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits
return logits.astype(mx.float32)
def rearrange_kv_cache(self, source_indices):
"""Update the key-value cache according to the updated beams"""
@ -401,7 +403,10 @@ class DecodingTask:
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
@ -542,7 +547,7 @@ class DecodingTask:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32):
return TypeError(
raise TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)

View File

@ -7,6 +7,7 @@ import warnings
from typing import List
import mlx.core as mx
from mlx.utils import tree_map
import torch
from tqdm import tqdm
@ -25,7 +26,8 @@ _MODELS = {
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@ -41,7 +43,8 @@ _ALIGNMENT_HEADS = {
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00"
}
@ -163,7 +166,7 @@ def convert(model, rules=None):
def torch_to_mlx(
torch_model: torch_whisper.Whisper,
torch_model: torch_whisper.Whisper, dtype: mx.Dtype = mx.float16,
) -> whisper.Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
@ -182,7 +185,8 @@ def torch_to_mlx(
params = convert(torch_model, rules)
mlx_model = whisper.Whisper(torch_model.dims)
mlx_model = whisper.Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
return mlx_model
@ -190,5 +194,6 @@ def torch_to_mlx(
def load_model(
name: str,
download_root: str = None,
dtype : mx.Dtype = mx.float32,
) -> whisper.Whisper:
return torch_to_mlx(load_torch_model(name, download_root))
return torch_to_mlx(load_torch_model(name, download_root), dtype)

View File

@ -109,6 +109,7 @@ LANGUAGES = {
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
@ -125,6 +126,7 @@ TO_LANGUAGE_CODE = {
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
@ -133,6 +135,7 @@ class Tokenizer:
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
@ -147,7 +150,7 @@ class Tokenizer:
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
@ -213,10 +216,13 @@ class Tokenizer:
if self.language is None:
raise ValueError("This tokenizer does not have language token configured")
if token := self.special_tokens.get(f"<|{self.language}|>", None):
return self.to_language_token(self.language)
def to_language_token(self, language):
if token := self.special_tokens.get(f"<|{language}|>", None):
return token
raise KeyError(f"Language {self.language} not found in tokenizer.")
raise KeyError(f"Language {language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
@ -224,7 +230,7 @@ class Tokenizer:
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
return tuple(result)[: self.num_languages]
@cached_property
def all_language_codes(self) -> Tuple[str]:
@ -271,7 +277,7 @@ class Tokenizer:
return tuple(sorted(result))
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}:
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
@ -324,7 +330,7 @@ class Tokenizer:
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2"):
def get_encoding(name: str = "gpt2", num_languages: int = 99):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
with open(vocab_path) as fid:
ranks = {
@ -337,7 +343,7 @@ def get_encoding(name: str = "gpt2"):
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
@ -364,6 +370,7 @@ def get_encoding(name: str = "gpt2"):
def get_tokenizer(
multilingual: bool,
*,
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
@ -384,6 +391,8 @@ def get_tokenizer(
language = None
task = None
encoding = get_encoding(name=encoding_name)
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
return Tokenizer(encoding=encoding, language=language, task=task)
return Tokenizer(
encoding=encoding, num_languages=num_languages, language=language, task=task
)

View File

@ -234,7 +234,8 @@ class Whisper(nn.Module):
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
@ -267,7 +268,11 @@ class Whisper(nn.Module):
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
    """Language count derived from the vocabulary size.

    Computed as ``n_vocab - 51765``, minus one more when the model is
    multilingual.
    """
    multilingual_offset = int(self.is_multilingual)
    return self.dims.n_vocab - 51765 - multilingual_offset
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""

View File

@ -43,9 +43,9 @@ class ModelHolder:
model_name = None
@classmethod
def get_model(cls, model: str, dtype: mx.Dtype):
    """Return the cached model, reloading when the name or dtype changes.

    The cache holds a single model. Keying it on the name alone would hand
    back a model loaded with a different dtype than the one requested, so the
    dtype of the cached model is tracked as well.
    """
    # ``model_dtype`` is only set by this method; fall back to None on first use.
    cached_dtype = getattr(cls, "model_dtype", None)
    if cls.model is None or model != cls.model_name or cached_dtype != dtype:
        cls.model = load_model(model, dtype=dtype)
        cls.model_name = model
        cls.model_dtype = dtype
    return cls.model
@ -114,12 +114,11 @@ def transcribe(
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
model = ModelHolder.get_model(model)
dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32
dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(model, dtype)
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-2] - N_FRAMES
if verbose:
@ -150,7 +149,12 @@ def transcribe(
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
def decode_with_fallback(segment: mx.array) -> DecodingResult:
temperatures = (

View File

@ -37,6 +37,10 @@ def sinusoids(length, channels, max_timescale=10000):
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes in float32 and restores the input dtype."""

    def __call__(self, x: mx.array) -> mx.array:
        input_dtype = x.dtype
        # Run the parent normalization on a float32 copy of the input.
        normalized = super().__call__(x.astype(mx.float32))
        return normalized.astype(input_dtype)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
@ -94,17 +98,17 @@ class ResidualAttentionBlock(nn.Module):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp1 = nn.Linear(n_state, n_mlp)
self.mlp2 = nn.Linear(n_mlp, n_state)
self.mlp_ln = nn.LayerNorm(n_state)
self.mlp_ln = LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
@ -119,15 +123,15 @@ class ResidualAttentionBlock(nn.Module):
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self._positional_embedding = sinusoids(n_ctx, n_state)
self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
self.ln_post = nn.LayerNorm(n_state)
self.ln_post = LayerNorm(n_state)
def __call__(self, x):
x = nn.gelu(self.conv1(x))
@ -144,7 +148,7 @@ class AudioEncoder(nn.Module):
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, dtype: mx.Dtype = mx.float16,
):
super().__init__()
@ -155,8 +159,8 @@ class TextDecoder(nn.Module):
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
self.ln = nn.LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx)
self.ln = LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype)
def __call__(self, x, xa, kv_cache=None):
"""
@ -181,7 +185,7 @@ class TextDecoder(nn.Module):
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
@ -190,6 +194,7 @@ class Whisper(nn.Module):
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
dtype,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
@ -197,6 +202,7 @@ class Whisper(nn.Module):
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
dtype,
)
def embed_audio(self, mel):
@ -210,7 +216,11 @@ class Whisper(nn.Module):
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
    """Language count derived from the vocabulary size.

    Computed as ``n_vocab - 51765``, minus one more when the model is
    multilingual.
    """
    multilingual_offset = int(self.is_multilingual)
    return self.dims.n_vocab - 51765 - multilingual_offset
detect_language = detect_language_function
decode = decode_function