a few examples

This commit is contained in:
Awni Hannun 2023-11-29 08:17:26 -08:00
parent e31d82d3ed
commit b243c1d8f4
32 changed files with 105181 additions and 2 deletions

129
.gitignore vendored Normal file
@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

5
.pre-commit-config.yaml Normal file
@@ -0,0 +1,5 @@
repos:
  - repo: https://github.com/psf/black
    rev: 22.10.0
    hooks:
      - id: black

README.md
@@ -1,2 +1,14 @@
# mlx-examples
Examples in the MLX framework
# Transformer LM
This is an example of a decoder-only Transformer LM. The only dependency is
MLX.
Run the example on the GPU with:
```
python main.py --gpu
```
By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.
To run the PyTorch, Jax, or TensorFlow examples, install the respective framework.

18
mnist/README.md Normal file
@@ -0,0 +1,18 @@
# MNIST
This example shows how to run some simple models on MNIST. The only
dependency is MLX.
Run the example with:
```
python main.py
```
By default the example runs on the CPU. To run on the GPU, use:
```
python main.py --gpu
```
To run the PyTorch or Jax examples, install the respective framework.
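For example, with PyTorch installed, the PyTorch script in this directory can be run on the GPU (MPS back-end) with:
```
python torch_main.py --gpu
```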

80
mnist/jax_main.py Normal file
@@ -0,0 +1,80 @@
import jax
import jax.numpy as jnp
import functools
import time
import mnist
def init_model(key, num_layers, input_dim, hidden_dim, output_dim):
params = []
layer_sizes = [hidden_dim] * num_layers
for idim, odim in zip([input_dim] + layer_sizes, layer_sizes + [output_dim]):
key, wk = jax.random.split(key, 2)
W = 1e-2 * jax.random.normal(wk, (idim, odim))
b = jnp.zeros((odim,))
params.append((W, b))
return params
def feed_forward(params, X):
for W, b in params[:-1]:
X = jnp.maximum(X @ W + b, 0)
W, b = params[-1]
return X @ W + b
def loss_fn(params, X, y):
logits = feed_forward(params, X)
logits = jax.nn.log_softmax(logits, 1)
return -jnp.mean(logits[jnp.arange(y.size), y])
@jax.jit
def eval_fn(params, X, y):
logits = feed_forward(params, X)
return jnp.mean(jnp.argmax(logits, axis=1) == y)
def batch_iterate(key, batch_size, X, y):
perm = jax.random.permutation(key, y.size)
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
if __name__ == "__main__":
seed = 0
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
# Load the data
train_images, train_labels, test_images, test_labels = mnist.mnist()
# Load the model
key, subkey = jax.random.split(jax.random.PRNGKey(seed))
params = init_model(
subkey, num_layers, train_images.shape[-1], hidden_dim, num_classes
)
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
update_fn = jax.jit(
functools.partial(jax.tree_map, lambda p, g: p - learning_rate * g)
)
for e in range(num_epochs):
tic = time.perf_counter()
key, subkey = jax.random.split(key)
for X, y in batch_iterate(subkey, batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(params, X, y)
params = update_fn(params, grads)
accuracy = eval_fn(params, test_images, test_labels)
toc = time.perf_counter()
print(
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
f" Time {toc - tic:.3f} (s)"
)

88
mnist/main.py Normal file
@@ -0,0 +1,88 @@
import argparse
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mnist
class MLP(nn.Module):
"""A simple MLP."""
def __init__(
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
):
super().__init__()
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
def __call__(self, x):
for l in self.layers[:-1]:
x = mx.maximum(l(x), 0.0)
return self.layers[-1](x)
def loss_fn(model, X, y):
return mx.mean(nn.losses.cross_entropy(model(X), y))
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
def batch_iterate(batch_size, X, y):
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
def main():
seed = 0
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
np.random.seed(seed)
# Load the data
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=learning_rate)
for e in range(num_epochs):
tic = time.perf_counter()
for X, y in batch_iterate(batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(model, X, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
accuracy = eval_fn(model, test_images, test_labels)
toc = time.perf_counter()
print(
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
f" Time {toc - tic:.3f} (s)"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)
main()

67
mnist/mnist.py Normal file
@@ -0,0 +1,67 @@
import gzip
import numpy as np
import os
import pickle
from urllib import request
def mnist(save_dir="/tmp"):
"""
Load the MNIST dataset in 4 tensors: train images, train labels,
test images, and test labels.
Checks `save_dir` for already downloaded data otherwise downloads.
Download code modified from:
https://github.com/hsjeong5/MNIST-for-Numpy
"""
def download_and_save(save_file):
base_url = "http://yann.lecun.com/exdb/mnist/"
filename = [
["training_images", "train-images-idx3-ubyte.gz"],
["test_images", "t10k-images-idx3-ubyte.gz"],
["training_labels", "train-labels-idx1-ubyte.gz"],
["test_labels", "t10k-labels-idx1-ubyte.gz"],
]
mnist = {}
for name in filename:
out_file = os.path.join("/tmp", name[1])
request.urlretrieve(base_url + name[1], out_file)
for name in filename[:2]:
out_file = os.path.join("/tmp", name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
-1, 28 * 28
)
for name in filename[-2:]:
out_file = os.path.join("/tmp", name[1])
with gzip.open(out_file, "rb") as f:
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
with open(save_file, "wb") as f:
pickle.dump(mnist, f)
save_file = os.path.join(save_dir, "mnist.pkl")
if not os.path.exists(save_file):
download_and_save(save_file)
with open(save_file, "rb") as f:
mnist = pickle.load(f)
preproc = lambda x: x.astype(np.float32) / 255.0
mnist["training_images"] = preproc(mnist["training_images"])
mnist["test_images"] = preproc(mnist["test_images"])
return (
mnist["training_images"],
mnist["training_labels"].astype(np.uint32),
mnist["test_images"],
mnist["test_labels"].astype(np.uint32),
)
if __name__ == "__main__":
train_x, train_y, test_x, test_y = mnist()
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
assert train_y.shape == (60000,), "Wrong training set size"
assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
assert test_y.shape == (10000,), "Wrong test set size"

88
mnist/torch_main.py Normal file
@@ -0,0 +1,88 @@
import argparse
import torch
import time
import mnist
class MLP(torch.nn.Module):
def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
super().__init__()
layer_sizes = [hidden_dim] * num_layers
self.layers = torch.nn.ModuleList(
[
torch.nn.Linear(idim, odim)
for idim, odim in zip(
[input_dim] + layer_sizes, layer_sizes + [output_dim]
)
]
)
def forward(self, x):
x = self.layers[0](x)
for l in self.layers[1:]:
x = l(x.relu())
return x
def loss_fn(model, X, y):
logits = model(X)
return torch.nn.functional.cross_entropy(logits, y)
@torch.no_grad()
def eval_fn(model, X, y):
logits = model(X)
return torch.mean((logits.argmax(-1) == y).float())
def batch_iterate(batch_size, X, y, device):
perm = torch.randperm(len(y), device=device)
for s in range(0, len(y), batch_size):
ids = perm[s : s + batch_size]
yield X[ids], y[ids]
if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if not args.gpu:
torch.set_num_threads(1)
device = "cpu"
else:
device = "mps"
seed = 0
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 10
learning_rate = 1e-1
# Load the data
def to_tensor(x):
if x.dtype != "uint32":
return torch.from_numpy(x).to(device)
else:
return torch.from_numpy(x.astype(int)).to(device)
train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist())
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
opt = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)
for e in range(num_epochs):
tic = time.perf_counter()
for X, y in batch_iterate(batch_size, train_images, train_labels, device):
opt.zero_grad()
loss_fn(model, X, y).backward()
opt.step()
accuracy = eval_fn(model, test_images, test_labels)
toc = time.perf_counter()
print(
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
f" Time {toc - tic:.3f} (s)"
)

14
transformer_lm/README.md Normal file
@@ -0,0 +1,14 @@
# Transformer LM
This is an example of a decoder-only Transformer LM. The only dependency is
MLX.
Run the example on the GPU with:
```
python main.py --gpu
```
By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.
To run the PyTorch, Jax, or TensorFlow examples, install the respective framework.
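For example, to train on WikiText-2 instead of PTB (the `--dataset` choices in `main.py` are `ptb`, `wikitext2`, and `wikitext103`):
```
python main.py --gpu --dataset wikitext2
```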

90
transformer_lm/datasets.py Normal file
@@ -0,0 +1,90 @@
import io
import itertools
import numpy as np
import os
from urllib import request
import zipfile
def load_dataset(dataname):
if dataname == "ptb":
return ptb()
elif dataname == "wikitext2":
return wikitext(dataset="2")
else:
return wikitext(dataset="103")
def _load(save_dir, filenames):
# *NB* First file is expected to be the training set
with open(os.path.join(save_dir, filenames[0]), "r") as fid:
vocab = set(t for l in fid.readlines() for t in l.strip().split(" "))
eos = "<eos>"
vocab.add(eos)
vocab = {v: i for i, v in enumerate(vocab)}
def to_array(dataset):
with open(os.path.join(save_dir, dataset), "r") as fid:
lines = (l.strip().split(" ") for l in fid.readlines())
return np.array(
[vocab[w] for line in lines for w in itertools.chain(line, [eos])],
dtype=np.uint32,
)
datasets = [to_array(fn) for fn in filenames]
return vocab, *datasets
def wikitext(dataset="2", save_dir="/tmp"):
"""
Load the WikiText-* language modeling dataset:
https://paperswithcode.com/dataset/wikitext-2
https://paperswithcode.com/dataset/wikitext-103
"""
if dataset not in ("2", "103"):
raise ValueError(f'Dataset must be either "2" or "103", got {dataset}')
filenames = ["wiki.train.tokens", "wiki.valid.tokens", "wiki.test.tokens"]
dataname = f"wikitext-{dataset}"
data_dir = os.path.join(save_dir, dataname)
if not os.path.exists(data_dir):
base_url = "https://s3.amazonaws.com/research.metamind.io/wikitext/"
zip_file_url = base_url + dataname + "-v1.zip"
r = request.urlopen(zip_file_url)
with zipfile.ZipFile(io.BytesIO(r.read())) as zf:
zf.extractall(save_dir)
return _load(data_dir, filenames)
def ptb(save_dir="/tmp"):
"""
Load the PTB language modeling dataset:
https://paperswithcode.com/dataset/penn-treebank
"""
filenames = [
"ptb.train.txt",
"ptb.valid.txt",
"ptb.test.txt",
]
def download_and_save(save_dir):
base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/"
for name in filenames:
out_file = os.path.join(save_dir, name)
if not os.path.exists(out_file):
request.urlretrieve(base_url + name, out_file)
save_dir = os.path.join(save_dir, "ptb")
if not os.path.exists(save_dir):
os.mkdir(save_dir)
download_and_save(save_dir)
return _load(save_dir, filenames)
if __name__ == "__main__":
vocab, train, val, test = ptb()
assert len(vocab) == 10000, "PTB: Wrong vocab size"
vocab, train, val, test = wikitext()
assert len(vocab) == 33279, "WikiText: Wrong vocab size"

303
transformer_lm/jax_main.py Normal file
@@ -0,0 +1,303 @@
import functools
import jax
import jax.numpy as jnp
import math
import numpy as np
import time
from collections import namedtuple
import datasets
from tree_utils import tree_flatten
"""
Some TODOs for this model:
- Positional encodings
- Dropout
- Adam optimizer
- Option for bigger datasets (wikitext / librispeech text < c4 < ...)
"""
RuntimeConfig = namedtuple("RuntimeConfig", "num_heads")
def embedding_init(key, num_embeddings, embed_dim):
return jax.random.uniform(
key, (num_embeddings, embed_dim), minval=-1e-1, maxval=1e-1
)
def embedding_apply(params, X):
return params.take(X, axis=0)
def dense_init(key, in_dim, out_dim, bias=True):
k1, k2 = jax.random.split(key)
scale = math.sqrt(1 / in_dim)
params = [jax.random.uniform(k1, (in_dim, out_dim), minval=-scale, maxval=scale)]
if bias:
params.append(jax.random.uniform(k2, (out_dim,), minval=-scale, maxval=scale))
return params
def dense_apply(params, X):
X = X @ params[0]
if len(params) == 2:
X = X + params[1]
return X
def layernorm_init(key, dim):
return [jnp.zeros((dim,)), jnp.ones((dim,))]
def layernorm_apply(params, X, epsilon=1e-6):
means = jnp.mean(X, axis=-1, keepdims=True)
var = jnp.var(X, axis=-1, keepdims=True)
X = (X - means) / jnp.sqrt(var + epsilon)
beta, gamma = params
return gamma * X + beta
def mlpblock_init(key, dim):
k1, k2 = jax.random.split(key)
return {
"dense1": dense_init(k1, dim, 4 * dim),
"dense2": dense_init(k2, 4 * dim, dim),
}
def mlpblock_apply(params, X):
X = dense_apply(params["dense1"], X)
X = jnp.maximum(X, 0)
# TODO dropout option here
return dense_apply(params["dense2"], X)
def selfattention_init(key, dim):
k1, k2, k3, k4 = jax.random.split(key, 4)
return {
"Q": dense_init(k1, dim, dim, bias=False),
"K": dense_init(k2, dim, dim, bias=False),
"V": dense_init(k3, dim, dim, bias=False),
"out": dense_init(k4, dim, dim, bias=False),
}
def selfattention_apply(params, num_heads, X, mask):
queries = dense_apply(params["Q"], X)
keys = dense_apply(params["K"], X)
values = dense_apply(params["V"], X)
B, L, D = queries.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ jnp.transpose(keys, (0, 1, 3, 2))
scores = jax.nn.softmax(scores + mask, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return dense_apply(params["out"], values_hat)
def transformer_init(key, token_set_size, num_blocks, dim):
key, ek = jax.random.split(key)
params = {"embedding": embedding_init(ek, token_set_size, dim)}
transformer_blocks = []
for b in range(num_blocks):
key, k1, k2, k3, k4 = jax.random.split(key, 5)
transformer_blocks.append(
{
"ln1": layernorm_init(k1, dim),
"ln2": layernorm_init(k2, dim),
"selfattention": selfattention_init(k3, dim),
"mlpblock": mlpblock_init(k4, dim),
}
)
params["transformer_blocks"] = transformer_blocks
params["output"] = dense_init(key, dim, token_set_size)
return params
def create_additive_causal_mask(N):
indices = jnp.arange(N)
mask = jnp.reshape(indices, (-1, 1)) < jnp.reshape(indices, (1, -1))
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
mask = mask.astype(jnp.float32) * -1e9
return mask
def transformer_apply(params, static_params, inputs):
mask = create_additive_causal_mask(inputs.shape[1])
X = embedding_apply(params["embedding"], inputs)
for block in params["transformer_blocks"]:
out = layernorm_apply(block["ln1"], X)
out = selfattention_apply(
block["selfattention"], static_params.num_heads, out, mask
)
X = X + out
out = layernorm_apply(block["ln2"], X)
out = mlpblock_apply(block["mlpblock"], out)
X = X + out
return dense_apply(params["output"], X)
@functools.partial(jax.jit, static_argnames=["static_params", "reduce"])
def loss_fn(params, static_params, inputs, targets, reduce=True):
logits = transformer_apply(params, static_params, inputs)
logits = jax.nn.log_softmax(logits, axis=-1)
sample_indices = jnp.arange(targets.shape[0])[:, None]
token_indices = jnp.arange(targets.shape[1])[None, :]
losses = -logits[sample_indices, token_indices, targets]
return jnp.mean(losses) if reduce else losses.mean(-1)
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(key, batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
def main(args):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
config = RuntimeConfig(args.num_heads)
# Load vocab and dataset:
vocab, train, valid, test = datasets.ptb()
# Initialize model:
key, subkey = jax.random.split(jax.random.PRNGKey(args.seed))
params = transformer_init(subkey, len(vocab), args.num_blocks, args.dim)
nparams = sum(x.size for k, x in tree_flatten(params) if "embedding" not in k)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
loss_and_grad_fn = jax.jit(
jax.value_and_grad(loss_fn), static_argnames=["static_params"]
)
update_fn = jax.jit(
functools.partial(jax.tree_map, lambda p, g: p - args.learning_rate * g)
)
def eval_fn(params, dataset):
inputs, targets = to_samples(context_size, dataset)
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
losses = loss_fn(params, config, bx, by, reduce=False)
loss += jnp.sum(losses)
return loss / len(targets)
train_iterator = iterate_batches(subkey, batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
loss, grads = loss_and_grad_fn(params, config, inputs, targets)
losses.append(loss.item())
params = update_fn(params, grads)
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(params, valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(params, test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss.item():.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with Jax.")
parser.add_argument(
"--seed", type=int, default=0, help="Seed for numpy and Jax RNGs."
)
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
main(args)

190
transformer_lm/main.py Normal file
@@ -0,0 +1,190 @@
import math
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
import datasets
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
self.out_proj = nn.Linear(dims, vocab_size)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
x = self.transformer(x, mask)
return self.out_proj(x)
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
mx.simplify(losses)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
def main(args):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
# Load vocab and dataset:
vocab, train, valid, test = datasets.load_dataset(args.dataset)
# Initialize model:
model = TransformerLM(len(vocab), args.num_blocks, args.dim, args.num_heads)
mx.eval(model.parameters())
nparams = sum(
x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
optimizer = optim.SGD(learning_rate=args.learning_rate)
loss_and_grad_fn = nn.value_and_grad(model, model.loss)
def eval_fn(dataset):
inputs, targets = map(mx.array, to_samples(context_size, dataset))
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(mx.array, (bx, by))
losses = model.loss(bx, by, reduce=False)
loss += mx.sum(losses).item()
return loss / len(targets)
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(mx.array, (inputs, targets))
loss, grads = loss_and_grad_fn(inputs, targets)
model.update(optimizer.apply_gradients(grads, model))
mx.simplify(loss, model.parameters())
mx.eval(loss, model.parameters())
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
parser.add_argument(
"--dataset",
type=str,
default="ptb",
choices=["ptb", "wikitext2", "wikitext103"],
help="Dataset to train and evaluate on.",
)
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
if not args.gpu:
mx.set_default_device(mx.cpu)
main(args)

249
transformer_lm/tf_main.py Normal file
@@ -0,0 +1,249 @@
import math
import time
import numpy as np
import tensorflow as tf
import datasets
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s + batch_size >= inputs.shape[0]:
s = 0
def create_additive_causal_mask(N):
indices = tf.range(N)
mask = tf.reshape(indices, (-1, 1)) < tf.reshape(indices, (1, -1))
return tf.cast(mask, tf.dtypes.float32) * -1e9
class SelfAttention(tf.keras.layers.Layer):
def __init__(self, num_heads, model_dims, context_size):
super().__init__()
self.Wq = tf.keras.layers.Dense(model_dims, use_bias=False)
self.Wk = tf.keras.layers.Dense(model_dims, use_bias=False)
self.Wv = tf.keras.layers.Dense(model_dims, use_bias=False)
self.Wo = tf.keras.layers.Dense(model_dims, use_bias=False)
self.causal_mask = create_additive_causal_mask(context_size)
self.num_heads = num_heads
self.head_dim = model_dims // num_heads
self.scale = tf.constant(1.0 / math.sqrt(self.head_dim))
def call(self, x):
queries = self.Wq(x)
keys = self.Wk(x)
values = self.Wv(x)
B, L, D = x.shape
queries = tf.transpose(
tf.reshape(queries, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
)
keys = tf.transpose(
tf.reshape(keys, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
)
values = tf.transpose(
tf.reshape(values, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
)
scores = (self.scale * queries) @ tf.transpose(keys, (0, 1, 3, 2))
scores = tf.nn.softmax(scores + self.causal_mask, axis=-1)
values = tf.matmul(scores, values)
values_hat = tf.reshape(tf.transpose(values, perm=(0, 2, 1, 3)), (B, L, -1))
return self.Wo(values_hat)
class EncoderLayer(tf.keras.layers.Layer):
def __init__(self, num_heads, model_dims, context_size):
super().__init__()
self._ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
self._self_attn = SelfAttention(num_heads, model_dims, context_size)
self._ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
self._mlp = tf.keras.Sequential(
[
tf.keras.layers.Dense(4 * model_dims, activation="relu"),
tf.keras.layers.Dense(model_dims),
]
)
def call(self, x):
x = x + self._self_attn(self._ln1(x))
x = x + self._mlp(self._ln2(x))
return x
class TransformerLM(tf.keras.Model):
def __init__(self, vocab_size, num_layers, num_heads, model_dims, context_size):
super().__init__()
self.embedding = tf.keras.layers.Embedding(vocab_size, model_dims)
self.transformer = tf.keras.Sequential(
[
EncoderLayer(num_heads, model_dims, context_size)
for _ in range(num_layers)
]
)
self.projection = tf.keras.layers.Dense(vocab_size)
def call(self, x):
x = self.embedding(x)
x = self.transformer(x)
x = self.projection(x)
return x
def main(args, device):
with tf.device(device):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
# Load vocab and dataset:
vocab, train, valid, test = datasets.ptb()
# Initialize model:
transformer = TransformerLM(
len(vocab), args.num_blocks, args.num_heads, args.dim, context_size
)
transformer.compile(
optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=args.learning_rate),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
transformer.build((batch_size, context_size))
nparams = sum(
np.prod(p.shape) for p in transformer.trainable_weights[1:]
) # [1:] to skip the embedding
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
def eval_fn(dataset):
inputs, targets = to_samples(context_size, dataset)
loss = 0
n_batches = 0
for s in range(0, targets.shape[0], batch_size):
if s + batch_size >= targets.shape[0]:
s = targets.shape[0] - 1 - batch_size
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(
lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
[bx, by],
)
loss += transformer.test_on_batch(bx, by)
n_batches += 1
return loss / n_batches
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(
lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
[inputs, targets],
)
loss = transformer.train_on_batch(inputs, targets)
losses.append(loss)
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
main(args, device="/GPU:0" if args.gpu else "/CPU:0")

197
transformer_lm/torch_main.py Normal file
@@ -0,0 +1,197 @@
import math
import time
import numpy as np
import torch
import datasets
def to_samples(context_size, dataset):
tokens = dataset.size
window_size = context_size + 1 # include target
samples = tokens - window_size + 1
X = np.lib.stride_tricks.as_strided(
dataset,
shape=(samples, window_size),
strides=(dataset.itemsize, dataset.itemsize),
)
return X[:, :-1], X[:, 1:]
def iterate_batches(batch_size, context_size, dataset):
inputs, targets = to_samples(context_size, dataset)
s = 0
while True:
if s == 0:
# Reset permutation:
perm = np.random.permutation(inputs.shape[0])
ids = perm[s : s + batch_size]
yield inputs[ids], targets[ids]
s += batch_size
if s >= inputs.shape[0]:
s = 0
def create_additive_causal_mask(N, device):
# torch.nn.Transformer.generate_square_subsequent_mask
# gives NaNs with `device="mps"`
indices = torch.arange(N, device=device)
mask = indices.reshape((-1, 1)) < indices.reshape((1, -1))
return mask.to(torch.float32) * -1e9
class TransformerLM(torch.nn.Module):
def __init__(self, vocab_size, num_layers, num_heads, model_dims):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, model_dims)
self.transformer = torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(
model_dims,
num_heads,
4 * model_dims,
dropout=0.0,
norm_first=True,
batch_first=True,
),
num_layers,
)
self.projection = torch.nn.Linear(model_dims, vocab_size)
def forward(self, x):
mask = create_additive_causal_mask(x.shape[1], device=x.device)
x = self.embedding(x)
x = self.transformer(x, mask=mask)
x = self.projection(x)
return x
def main(args, device):
batch_size = args.batch_size
context_size = args.context_size
steps_per_eval = args.steps_per_eval
steps_per_report = args.steps_per_report
# Load vocab and dataset:
vocab, train, valid, test = datasets.ptb()
# Initialize model:
transformer = TransformerLM(len(vocab), args.num_blocks, args.num_heads, args.dim)
transformer = transformer.to(device)
optim = torch.optim.SGD(transformer.parameters(), lr=args.learning_rate, momentum=0)
nparams = sum(
p.numel() for n, p in transformer.named_parameters() if "embedding" not in n
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
@torch.no_grad()
def eval_fn(dataset):
inputs, targets = to_samples(context_size, dataset)
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(lambda x: torch.from_numpy(x.astype(int)).to(device), [bx, by])
logits = transformer(bx)
losses = torch.nn.functional.cross_entropy(
logits.flatten(0, 1), by.flatten(), reduction="none"
)
losses = losses.view(-1, by.shape[-1]).mean(-1)
loss += losses.sum().item()
return loss / len(targets)
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(
lambda x: torch.from_numpy(x.astype(int)).to(device), [inputs, targets]
)
optim.zero_grad()
logits = transformer(inputs)
loss = torch.nn.functional.cross_entropy(
logits.flatten(0, 1), targets.flatten()
)
loss.backward()
optim.step()
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)
toc = time.perf_counter()
print(
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
f"It/sec {steps_per_report / (toc - tic):.3f}"
)
losses = []
tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(valid)
toc = time.perf_counter()
print(
f"Iter {it + 1}: "
f"Val loss {val_loss:.3f}, "
f"Val ppl {math.exp(val_loss):.3f}, "
f"Val took {(toc - tic):.3f}s, "
)
tic = time.perf_counter()
if args.eval_test:
test_loss = eval_fn(test)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
parser.add_argument(
"--context_size",
type=int,
default=1024,
help="Context size in tokens of the model.",
)
parser.add_argument(
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
)
parser.add_argument(
"--dim",
type=int,
default=1024,
help="Dimensionality of embeddings and hidden layers.",
)
parser.add_argument(
"--num_heads",
type=int,
default=16,
help="Number of heads used for multi-head attention",
)
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
parser.add_argument(
"--num_iters", type=int, default=100000, help="Iterations to train for."
)
parser.add_argument(
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
)
parser.add_argument(
"--steps_per_report",
type=int,
default=10,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps_per_eval",
type=int,
default=1000,
help="Number of training steps between validations.",
)
parser.add_argument(
"--eval_test",
action="store_true",
help="Evaluate on the test set after training",
)
args = parser.parse_args()
main(args, device="mps" if args.gpu else "cpu")

27
whisper/README.md Normal file
@@ -0,0 +1,27 @@
# whisper
Whisper in MLX.
First install the dependencies:
(TODO, MLX install link / command / add to requirements.txt)
```
pip install -r requirements.txt
```
Install [`ffmpeg`](https://ffmpeg.org/):
```bash
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```
Then transcribe audio with:
```
import whisper
text = whisper.transcribe(speech_file)["text"]
```
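As a slightly fuller sketch, assuming the bundled test clip used by `benchmark.py` is available at `whisper/assets/ls_test.flac`:
```
import whisper

result = whisper.transcribe("whisper/assets/ls_test.flac")
print(result["text"])
print(result["language"])
```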

91
whisper/benchmark.py Normal file
@@ -0,0 +1,91 @@
import time
import mlx.core as mx
from whisper import load_models
from whisper import audio
from whisper import decoding
from whisper import transcribe
audio_file = "whisper/assets/ls_test.flac"
def timer(fn, *args):
for _ in range(5):
fn(*args)
num_its = 10
tic = time.perf_counter()
for _ in range(num_its):
fn(*args)
toc = time.perf_counter()
return (toc - tic) / num_its
def feats():
data = audio.load_audio(audio_file)
data = audio.pad_or_trim(data)
mels = audio.log_mel_spectrogram(data)
mx.eval(mels)
return mels
def model_forward(model, mels, tokens):
logits = model(mels, tokens)
mx.eval(logits)
return logits
def decode(model, mels):
return decoding.decode(model, mels)
def everything():
return transcribe(audio_file)
if __name__ == "__main__":
feat_time = timer(feats)
print(f"Feature time {feat_time:.3f}")
mels = feats()[None]
tokens = mx.array(
[
50364,
1396,
264,
665,
5133,
23109,
25462,
264,
6582,
293,
750,
632,
42841,
292,
370,
938,
294,
4054,
293,
12653,
356,
50620,
50620,
23563,
322,
3312,
13,
50680,
],
mx.int32,
)[None]
model = load_models.load_model("tiny")
model_forward_time = timer(model_forward, model, mels, tokens)
print(f"Model forward time {model_forward_time:.3f}")
decode_time = timer(decode, model, mels)
print(f"Decode time {decode_time:.3f}")
everything_time = timer(everything)
print(f"Everything time {everything_time:.3f}")

6
whisper/requirements.txt Normal file
@@ -0,0 +1,6 @@
numba
numpy
torch
tqdm
more-itertools
tiktoken==0.3.3

270
whisper/test.py Normal file
@@ -0,0 +1,270 @@
import unittest
import mlx.core as mx
import numpy as np
import os
import subprocess
import torch
import whisper
import whisper.audio as audio
import whisper.load_models as load_models
import whisper.torch_whisper as torch_whisper
import whisper.decoding as decoding
TEST_AUDIO = "whisper/assets/ls_test.flac"
def forward_torch(model, mels, tokens):
mels = torch.Tensor(mels).to(torch.float32)
tokens = torch.Tensor(tokens).to(torch.int32)
with torch.no_grad():
logits = model.forward(mels, tokens)
return logits.numpy()
def forward_mlx(model, mels, tokens):
mels = mx.array(mels.transpose(0, 2, 1))
tokens = mx.array(tokens, mx.int32)
logits = model(mels, tokens)
return np.array(logits)
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = load_models.load_model("tiny")
data = audio.load_audio(TEST_AUDIO)
data = audio.pad_or_trim(data)
cls.mels = audio.log_mel_spectrogram(data)
def test_torch_mlx(self):
np.random.seed(10)
torch_model = load_models.load_torch_model("tiny")
dims = torch_model.dims
mels = np.random.randn(1, dims.n_mels, 3_000)
tokens = np.random.randint(0, dims.n_vocab, (1, 20))
torch_logits = forward_torch(torch_model, mels, tokens)
mlx_model = load_models.torch_to_mlx(torch_model)
mlx_logits = forward_mlx(mlx_model, mels, tokens)
self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
def test_decode_lang(self):
options = decoding.DecodingOptions(task="lang_id")
result = decoding.decode(self.model, self.mels, options)
self.assertEqual(result.language, "en")
self.assertEqual(len(result.language_probs), 99)
self.assertAlmostEqual(
result.language_probs["en"], 0.9947282671928406, places=5
)
def test_decode_greedy(self):
result = decoding.decode(self.model, self.mels)
self.assertEqual(result.language, "en")
self.assertEqual(
result.tokens,
[
50364,
1396,
264,
665,
5133,
23109,
25462,
264,
6582,
293,
750,
632,
42841,
292,
370,
938,
294,
4054,
293,
12653,
356,
50620,
50620,
23563,
322,
3312,
13,
50680,
],
)
self.assertEqual(
result.text,
(
"Then the good soul openly sorted the boat and she "
"had buoyed so long in secret and bravely stretched on alone."
),
)
self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
# Small temp should give the same results
result = decoding.decode(self.model, self.mels, temperature=1e-8)
self.assertEqual(
result.text,
(
"Then the good soul openly sorted the boat and she "
"had buoyed so long in secret and bravely stretched on alone."
),
)
self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
result = whisper.transcribe(TEST_AUDIO)
self.assertEqual(
result["text"],
(
" Then the good soul openly sorted the boat and she "
"had buoyed so long in secret and bravely stretched on alone."
),
)
def test_transcribe_alice(self):
audio_file = os.path.join(
os.path.expanduser("~"),
".cache/whisper/alice.mp3",
)
if not os.path.exists(audio_file):
print("To run this test download the alice in wonderland audiobook:")
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
result = whisper.transcribe(audio_file)
self.assertEqual(len(result["text"]), 10920)
self.assertEqual(result["language"], "en")
self.assertEqual(len(result["segments"]), 77)
expected_5 = {
"id": 5,
"seek": 2800,
"start": 40.0,
"end": 46.0,
"text": " Oh my poor little feet, I wonder who will put on your shoes and stockings for you now tears.",
"tokens": [
50964,
876,
452,
4716,
707,
3521,
11,
286,
2441,
567,
486,
829,
322,
428,
6654,
293,
4127,
1109,
337,
291,
586,
10462,
13,
51264,
],
"temperature": 0.0,
"avg_logprob": -0.19670599699020386,
"compression_ratio": 1.5991379310344827,
"no_speech_prob": 0.09746722131967545,
}
expected_73 = {
"id": 73,
"seek": 70700,
"start": 707.0,
"end": 715.0,
"text": " let us get to the shore, and then I'll tell you my history, and you'll understand why it is that I hate cats and dogs.",
"tokens": [
50364,
718,
505,
483,
281,
264,
17805,
11,
293,
550,
286,
603,
980,
291,
452,
2503,
11,
293,
291,
603,
1223,
983,
309,
307,
300,
286,
4700,
11111,
293,
7197,
13,
50764,
],
"temperature": 0.0,
"avg_logprob": -0.1350895343440594,
"compression_ratio": 1.6208333333333333,
"no_speech_prob": 0.002246702555567026,
}
def check_segment(seg, expected):
for k, v in expected.items():
if isinstance(v, float):
self.assertAlmostEqual(seg[k], v, places=3)
else:
self.assertEqual(seg[k], v)
# Randomly check a couple of segments
check_segment(result["segments"][5], expected_5)
check_segment(result["segments"][73], expected_73)
class TestAudio(unittest.TestCase):
def test_load(self):
data = audio.load_audio(TEST_AUDIO)
data_8k = audio.load_audio(TEST_AUDIO, 8000)
n = 106640
self.assertTrue(data.shape, (n,))
self.assertTrue(data.dtype, np.float32)
self.assertTrue(data_8k.shape, (n // 2,))
def test_pad(self):
data = audio.load_audio(TEST_AUDIO)
data = audio.pad_or_trim(data, 20_000)
self.assertTrue(data.shape, [20_000])
def test_mel_spec(self):
mels = audio.log_mel_spectrogram(TEST_AUDIO)
self.assertTrue(mels.shape, [80, 400])
self.assertTrue(mels.dtype, mx.float32)
if __name__ == "__main__":
unittest.main()

4
whisper/whisper/__init__.py Normal file
@@ -0,0 +1,4 @@
from . import load_models
from . import audio
from . import decoding
from .transcribe import transcribe

10
whisper/whisper/assets/download_alice.sh Normal file
@@ -0,0 +1,10 @@
#!/bin/bash
audio_file=$HOME/.cache/whisper/alice.mp3
echo $audio_file
zipf=alice_in_wonderland_librivox_64kb_mp3.zip
url=https://www.archive.org/download/alice_in_wonderland_librivox/
curl -LO $url/$zipf
unzip $zipf
mv wonderland_ch_02_64kb.mp3 $audio_file
rm wonderland_* $zipf

File diff suppressed because it is too large

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large

173
whisper/whisper/audio.py Normal file
@@ -0,0 +1,173 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union
import mlx.core as mx
import numpy as np
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = N_SAMPLES // HOP_LENGTH # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if array.shape[axis] > length:
sl = [slice(None)] * array.ndim
sl[axis] = slice(0, length)
array = array[tuple(sl)]
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
pad_fn = mx.pad if isinstance(array, mx.array) else np.pad
array = pad_fn(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(n_mels: int = N_MELS) -> mx.array:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
return mx.load(filename)[f"mel_{n_mels}"]
@lru_cache(maxsize=None)
def hanning(size):
return mx.array(np.hanning(size + 1)[:-1])
def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="reflect"):
if nfft is None:
nfft = nperseg
if noverlap is None:
noverlap = nfft // 4
def _pad(x, padding, pad_mode="constant"):
if pad_mode == "constant":
return mx.pad(x, [(padding, padding)])
elif pad_mode == "reflect":
prefix = x[1 : padding + 1][::-1]
suffix = x[-(padding + 1) : -1][::-1]
return mx.concatenate([prefix, x, suffix])
else:
raise ValueError(f"Invalid pad_mode {pad_mode}")
padding = nperseg // 2
x = _pad(x, padding, pad_mode)
strides = [noverlap, 1]
t = (x.size - nperseg + noverlap) // noverlap
shape = [t, nfft]
x = mx.as_strided(x, shape=shape, strides=strides)
return mx.fft.rfft(x * window)
def log_mel_spectrogram(
audio: Union[str, np.ndarray],
n_mels: int = N_MELS,
padding: int = 0,
):
"""
Compute the log-Mel spectrogram of the input audio.
Parameters
----------
audio: Union[str, np.ndarray, mx.array], shape = (*)
The path to audio or either a NumPy or mlx array containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
padding: int
Number of zero samples to pad to the right
Returns
-------
mx.array, shape = (n_frames, 80)
An array that contains the log-Mel spectrogram
"""
device = mx.default_device()
mx.set_default_device(mx.cpu)
if not isinstance(audio, mx.array):
if isinstance(audio, str):
audio = load_audio(audio)
audio = mx.array(audio)
if padding > 0:
audio = mx.pad(audio, (0, padding))
window = hanning(N_FFT)
freqs = stft(audio, window, nperseg=N_FFT, noverlap=HOP_LENGTH)
magnitudes = freqs[:-1, :].abs().square()
filters = mel_filters(n_mels)
mel_spec = magnitudes @ filters.T
log_spec = mx.maximum(mel_spec, 1e-10).log10()
log_spec = mx.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
mx.set_default_device(device)
return log_spec

718
whisper/whisper/decoding.py Normal file
@@ -0,0 +1,718 @@
from dataclasses import dataclass, field, replace
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_map
import mlx.nn as nn
import numpy as np
import zlib
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
def compression_ratio(text) -> float:
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def detect_language(
model: "Whisper", mel: mx.array, tokenizer: Tokenizer = None
) -> Tuple[mx.array, List[dict]]:
"""
Detect the spoken language in the audio and return the ids of the most probable
language tokens, along with the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : mx.array, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
):
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
)
single = mel.ndim == 2
if single:
mel = mel[None]
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]:
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = mx.array([[tokenizer.sot]] * n_audio) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
mask[list(tokenizer.all_language_tokens)] = 0.0
logits += mx.array(mask)
language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1)
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
# whether to perform X->X "transcribe" or X->English "translate"
task: str = "transcribe"
# language that the audio is in; uses detected language if None
language: Optional[str] = None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
# "alpha" in Google NMT, or None for length norm, when ranking generations
# to select which to return among the beams or best-of-N samples
length_penalty: Optional[float] = None
# text or tokens to feed as the prompt or the prefix; for more info:
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt: Optional[Union[str, List[int]]] = None # for the previous context
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = False # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
audio_features: mx.array
language: str
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = None
def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
"""Perform a forward pass on the decoder and return per-token logits"""
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
logits, self.kv_cache = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits
def rearrange_kv_cache(self, source_indices):
"""Update the key-value cache according to the updated beams"""
# update the key/value cache to contain the selected sequences
if source_indices != list(range(len(source_indices))):
self.kv_cache = tree_map(lambda x: x[source_indices], self.kv_cache)
def reset(self):
self.kv_cache = None
class SequenceRanker:
def rank(
self, tokens: List[List[mx.array]], sum_logprobs: List[List[float]]
) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[List[int]]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
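# Worked example of the ranking rule above (illustrative numbers): for a 10-token
# candidate with summed logprob -6.0, length_penalty=0.6 gives
# penalty = ((5 + 10) / 6) ** 0.6 ≈ 1.73 and score ≈ -6.0 / 1.73 ≈ -3.46, while
# length_penalty=None falls back to plain length normalization, -6.0 / 10 = -0.6.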
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool, mx.array]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : mx.array, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : mx.array, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : mx.array, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : mx.array, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
            True if all sequences have reached the end of text
sum_logprobs: mx.array, shape = (n_batch)
updated cumulative log probabilities for each sequence
"""
raise NotImplementedError
def finalize(
self, tokens: mx.array, sum_logprobs: mx.array
) -> Tuple[Sequence[Sequence[mx.array]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : mx.array, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : mx.array, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[mx.array]], length = n_audio
sequence of mx.arrays containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
) -> Tuple[mx.array, bool, mx.array]:
        if self.temperature == 0:
            next_tokens = logits.argmax(axis=-1)
        else:
            next_tokens = mx.random.categorical(logits=logits / self.temperature)
logits = logits.astype(mx.float32)
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
eot_mask = tokens[:, -1] == self.eot
next_tokens = next_tokens * (1 - eot_mask) + self.eot * eot_mask
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=-1)
completed = mx.all(tokens[:, -1] == self.eot)
return tokens, completed, sum_logprobs
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
# make sure each sequence has at least one EOT token at the end
tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
return tokens, sum_logprobs.tolist()
class LogitFilter:
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
"""Apply any filtering or masking to logits
Parameters
----------
logits : mx.array, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : mx.array, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int, n_vocab: int):
self.sample_begin = sample_begin
mask = np.zeros(n_vocab, np.float32)
mask[tokenizer.encode(" ") + [tokenizer.eot]] = -np.inf
self.mask = mx.array(mask)
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
if tokens.shape[1] == self.sample_begin:
return logits + self.mask
return logits
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int], n_vocab: int):
mask = np.zeros(n_vocab, np.float32)
mask[list(suppress_tokens)] = -np.inf
self.mask = mx.array(mask)
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
return logits + self.mask
class ApplyTimestampRules(LogitFilter):
def __init__(
self,
tokenizer: Tokenizer,
sample_begin: int,
max_initial_timestamp_index: Optional[int],
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
mask = np.zeros(logits.shape, np.float32)
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
mask[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
sampled_tokens = tokens[k, self.sample_begin :]
seq = sampled_tokens.tolist()
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
)
penultimate_was_timestamp = (
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
)
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
mask[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
mask[k, : self.tokenizer.eot] = -np.inf
            timestamps = [t for t in seq if t >= self.tokenizer.timestamp_begin]
            if len(timestamps) > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                last_timestamp = timestamps[-1]
                if not last_was_timestamp or penultimate_was_timestamp:
                    last_timestamp += 1
                mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
mask[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = (
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
)
mask[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
axis=-1
)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
mask[k, : self.tokenizer.timestamp_begin] = -np.inf
return logits + mx.array(mask, logits.dtype)
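# Illustrative effect of the rules above on one sampled window (token ids omitted):
#   <|0.00|> And so my fellow Americans ... <|4.20|><|4.20|> ask not ... <|7.80|><|eot|>
# The first sampled token must be a timestamp (capped by max_initial_timestamp),
# timestamps open and close each chunk in pairs and never decrease, and a lone
# timestamp is only allowed directly before <|eot|>.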
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual, language=language, task=options.task
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = Inference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
raise NotImplementedError("Beam search decoder is not yet implemented")
# self.decoder = BeamSearchDecoder(
# options.beam_size, tokenizer.eot, self.inference, options.patience
# )
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(
SuppressBlank(self.tokenizer, self.sample_begin, model.dims.n_vocab)
)
if self.options.suppress_tokens:
self.logit_filters.append(
SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
)
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(
self.options.max_initial_timestamp / precision
)
self.logit_filters.append(
ApplyTimestampRules(
tokenizer, self.sample_begin, max_initial_timestamp_index
)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (
0 <= options.length_penalty <= 1
):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip())
if isinstance(prefix, str)
else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip())
if isinstance(prompt, str)
else prompt
)
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
return tuple(tokens)
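    # A hedged sketch of the resulting layout when both a prompt and a prefix are given,
    # using the multilingual special-token ids for English transcription:
    #
    #   [<|startofprev|>] + prompt_tokens[-(n_ctx // 2 - 1):]
    #       + [<|startoftranscript|>, <|en|>, <|transcribe|>] + prefix_tokens
    #   = [50361, ...prompt..., 50258, 50259, 50359, ...prefix...]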
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
def _get_audio_features(self, mel: mx.array):
if self.options.fp16:
mel = mel.astype(mx.float16)
if mel.shape[-2:] == (
self.model.dims.n_audio_ctx,
self.model.dims.n_audio_state,
):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32):
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )
return audio_features
def _detect_language(self, audio_features: mx.array, tokens: np.array):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer
)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
# write language tokens
tokens[:, self.sot_index + 1] = np.array(lang_tokens)
return languages, lang_probs
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
n_batch = tokens.shape[0]
sum_logprobs: mx.array = mx.zeros(n_batch)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = mx.softmax(
logits[:, self.sot_index].astype(mx.float32), axis=-1
)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
                # apply the logit filters, e.g. to suppress or penalize certain tokens
for logit_filter in self.logit_filters:
logits = logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed, sum_logprobs = self.decoder.update(
tokens, logits, sum_logprobs
)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.reset()
return tokens, sum_logprobs, no_speech_probs
def run(self, mel: mx.array) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
tokens: np.array = np.array(self.initial_tokens)
tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(
audio_features=features, language=language, language_probs=probs
)
for features, language, probs in zip(
audio_features, languages, language_probs
)
]
# repeat tokens by the group size, for beam search or best-of-n sampling
tokens = mx.array(tokens)
if self.n_group > 1:
tokens = tokens[:, None, :]
tokens = mx.broadcast_to(
tokens, [n_audio, self.n_group, len(self.initial_tokens)]
)
            tokens = tokens.reshape(n_audio * self.n_group, len(self.initial_tokens))
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens = tokens[..., self.sample_begin :].tolist()
tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i] for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (
texts,
languages,
tokens,
audio_features,
avg_logprobs,
no_speech_probs,
)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
)
]
def decode(
model: "Whisper",
mel: mx.array,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: mx.array, shape = (80, 3000) or (*, 80, 3000)
An array containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
if single := mel.ndim == 2:
mel = mel[None]
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
return result[0] if single else result
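# A hedged usage sketch for `decode()`. `load_model` comes from load_models.py and the
# audio helpers from audio.py; "audio.wav" is a hypothetical path, and the mel segment is
# padded/trimmed to a single 30-second window before decoding.
#
#   model = load_model("tiny")
#   mel = pad_or_trim(log_mel_spectrogram("audio.wav"), N_FRAMES, axis=-2)
#   result = decode(model, mel, DecodingOptions(temperature=0.0, without_timestamps=True))
#   print(result.language, result.text)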


@ -0,0 +1,192 @@
import hashlib
import os
import urllib
import warnings
from typing import List
import mlx.core as mx
import torch
from tqdm import tqdm
from . import whisper
from . import torch_whisper
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
}
def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_torch_model(
name: str,
download_root: str = None,
) -> torch_whisper.Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name]
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
with open(checkpoint_file, "rb") as fp:
checkpoint = torch.load(fp)
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
model = torch_whisper.Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model
def convert(model, rules=None):
params = {}
if rules is not None and type(model) in rules:
out = rules[type(model)](model, rules)
return out
if isinstance(model, torch.Tensor):
return mx.array(model.detach().numpy())
if isinstance(model, torch.nn.ModuleList):
return [convert(n, rules) for n in model.children()]
if isinstance(model, torch.nn.Conv1d):
return {
"weight": convert(model.weight).transpose(0, 2, 1),
"bias": convert(model.bias),
}
for k, n in model.named_children():
if k in rules:
params.update(rules[k](n, rules))
else:
params[k] = convert(n, rules)
for k, p in model.named_parameters(recurse=False):
params[k] = convert(p)
return params
def torch_to_mlx(
torch_model: torch_whisper.Whisper,
) -> whisper.Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
mlp = list(children.pop("mlp").children())
params = {
"mlp1": convert(mlp[0], rules),
"mlp2": convert(mlp[-1], rules),
}
for k, n in children.items():
params[k] = convert(n, rules)
return params
rules = {
torch_whisper.ResidualAttentionBlock: convert_rblock,
}
params = convert(torch_model, rules)
mlx_model = whisper.Whisper(torch_model.dims)
mlx_model.update(params)
return mlx_model
def load_model(
name: str,
download_root: str = None,
) -> whisper.Whisper:
return torch_to_mlx(load_torch_model(name, download_root))
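# A hedged usage sketch: download the official "tiny" checkpoint (to ~/.cache/whisper by
# default), convert it to MLX, and inspect its dimensions.
#
#   model = load_model("tiny")
#   print(model.dims.n_audio_ctx, model.dims.n_text_ctx)  # 1500 448 for the released checkpoints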

385
whisper/whisper/timing.py Normal file

@ -0,0 +1,385 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import numba
import numpy as np
import torch
import torch.nn.functional as F
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
if TYPE_CHECKING:
from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
# F.pad requires the padding width to be smaller than the input dimension
return x
if (ndim := x.ndim) <= 2:
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)
if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2:
result = result[0, 0]
return result
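# Worked example (illustrative): with reflect padding, a width-3 median filter over
# torch.tensor([1., 5., 2., 8., 3.]) slides over the padded sequence [5, 1, 5, 2, 8, 3, 8]
# and returns tensor([5., 2., 5., 3., 8.]).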
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
result = []
while i > 0 or j > 0:
result.append((i - 1, j - 1))
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise ValueError("Unexpected trace[i, j]")
result = np.array(result)
return result[::-1, :].T
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = x[i - 1, j - 1] + c
trace[i, j] = t
return backtrace(trace)
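# Worked example (illustrative): for the 2x3 cost matrix
#   x = np.array([[0.0, 1.0, 2.0],
#                 [1.0, 0.0, 0.5]], dtype=np.float32)
# dtw_cpu(x) returns array([[0, 1, 1],
#                           [0, 1, 2]]),
# where the first row indexes rows of x (text tokens) and the second row indexes columns
# (time frames), tracing the cheapest monotonic path (0, 0) -> (1, 1) -> (1, 2).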
def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)
return dtw_cpu(x.double().cpu().numpy())
@dataclass
class WordTiming:
word: str
tokens: List[int]
start: float
end: float
probability: float
def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
num_frames: int,
*,
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> List[WordTiming]:
if len(text_tokens) == 0:
return []
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
# array([0.])
# This results in crashes when we lookup jump_times with float, like
# IndexError: arrays used as indices must be of integer (or boolean) type
return []
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous.word.startswith(" ") and previous.word.strip() in prepended:
# prepend it to the following word
following.word = previous.word + following.word
following.tokens = previous.tokens + following.tokens
previous.word = ""
previous.tokens = []
else:
j = i
i -= 1
# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous.word.endswith(" ") and following.word in appended:
# append it to the previous word
previous.word = previous.word + following.word
previous.tokens = previous.tokens + following.tokens
following.word = ""
following.tokens = []
else:
i = j
j += 1
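# Illustrative example of the appended-punctuation pass above: word timings for
# [" Hello", " world", "!"] become [" Hello", " world!", ""] -- the "!" entry is emptied
# and its tokens folded into the preceding word, so add_word_timestamps (below) simply
# skips entries whose word is "".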
def add_word_timestamps(
*,
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
last_speech_timestamp: float,
**kwargs,
):
if len(segments) == 0:
return
text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(word_durations) > 0:
sentence_end_marks = ".。!?"
# ensure words at sentence boundaries are not longer than twice the median word duration.
for i in range(1, len(alignment)):
if alignment[i].end - alignment[i].start > max_duration:
if alignment[i].word in sentence_end_marks:
alignment[i].end = alignment[i].start + max_duration
elif alignment[i - 1].word in sentence_end_marks:
alignment[i].start = alignment[i].end - max_duration
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
word_index = 0
for segment, text_tokens in zip(segments, text_tokens_per_segment):
saved_tokens = 0
words = []
while word_index < len(alignment) and saved_tokens < len(text_tokens):
timing = alignment[word_index]
if timing.word:
words.append(
dict(
word=timing.word,
start=round(time_offset + timing.start, 2),
end=round(time_offset + timing.end, 2),
probability=timing.probability,
)
)
saved_tokens += len(timing.tokens)
word_index += 1
# hack: truncate long words at segment boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(words) > 0:
# ensure the first and second word after a pause is not longer than
# twice the median word duration.
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
words[0]["end"] - words[0]["start"] > max_duration
or (
len(words) > 1
and words[1]["end"] - words[0]["start"] > max_duration * 2
)
):
if (
len(words) > 1
and words[1]["end"] - words[1]["start"] > max_duration
):
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
words[0]["end"] = words[1]["start"] = boundary
words[0]["start"] = max(0, words[0]["end"] - max_duration)
# prefer the segment-level start timestamp if the first word is too long.
if (
segment["start"] < words[0]["end"]
and segment["start"] - 0.5 > words[0]["start"]
):
words[0]["start"] = max(
0, min(words[0]["end"] - median_duration, segment["start"])
)
else:
segment["start"] = words[0]["start"]
# prefer the segment-level end timestamp if the last word is too long.
if (
segment["end"] > words[-1]["start"]
and segment["end"] + 0.5 < words[-1]["end"]
):
words[-1]["end"] = max(
words[-1]["start"] + median_duration, segment["end"]
)
else:
segment["end"] = words[-1]["end"]
last_speech_timestamp = segment["end"]
segment["words"] = words


@ -0,0 +1,387 @@
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple
import tiktoken
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
@dataclass
class Tokenizer:
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
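    # Example for the multilingual vocabulary (n_vocab = 51865): with language="en" and
    # task="transcribe", the computed sot_sequence is
    #   (50258, 50259, 50359)  # <|startoftranscript|>, <|en|>, <|transcribe|>
    # and sot_sequence_including_notimestamps appends 50363 (<|notimestamps|>).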
def encode(self, text, **kwargs):
return self.encoding.encode(text, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
return self.encoding.decode(token_ids, **kwargs)
@cached_property
def eot(self) -> int:
return self.encoding.eot_token
@cached_property
def transcribe(self) -> int:
return self.special_tokens["<|transcribe|>"]
@cached_property
def translate(self) -> int:
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self.special_tokens["<|startoftranscript|>"]
@cached_property
def sot_lm(self) -> int:
return self.special_tokens["<|startoflm|>"]
@cached_property
def sot_prev(self) -> int:
return self.special_tokens["<|startofprev|>"]
@cached_property
def no_speech(self) -> int:
return self.special_tokens["<|nospeech|>"]
@cached_property
def no_timestamps(self) -> int:
return self.special_tokens["<|notimestamps|>"]
@cached_property
def timestamp_begin(self) -> int:
return self.special_tokens["<|0.00|>"]
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError("This tokenizer does not have language token configured")
if token := self.special_tokens.get(f"<|{self.language}|>", None):
return token
raise KeyError(f"Language {self.language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
        - ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2"):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
with open(vocab_path) as fid:
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in fid if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
encoding_name = "gpt2"
language = None
task = None
encoding = get_encoding(name=encoding_name)
return Tokenizer(encoding=encoding, language=language, task=task)


@ -0,0 +1,301 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
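    # Note on the scaling above: multiplying both q and k by (n_state // self.n_head) ** -0.25
    # scales the product q @ k by head_dim ** -0.5, the usual 1/sqrt(d_k) factor of scaled
    # dot-product attention, while keeping the intermediate q and k values in a smaller range
    # (helpful when running in float16).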
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
# use the last half layers for alignment by default; see `set_alignment_heads()` below
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects that can be used to remove the installed hooks
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
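# A hedged usage sketch of the kv-cache hooks above, for incremental decoding:
#
#   cache, hooks = model.install_kv_cache_hooks()
#   logits = model.decoder(tokens, audio_features, kv_cache=cache)       # full prefix once
#   logits = model.decoder(next_token, audio_features, kv_cache=cache)   # then one token at a time
#   for hook in hooks:
#       hook.remove()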


@ -0,0 +1,358 @@
import mlx.core as mx
import numpy as np
import sys
from typing import Optional, Tuple, Union
import tqdm
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .load_models import load_model
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
def _format_timestamp(seconds: float):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if hours > 0 else ""
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
class ModelHolder:
model = None
model_name = None
@classmethod
def get_model(cls, model: str):
if cls.model is None or model != cls.model_name:
cls.model = load_model(model)
cls.model_name = model
return cls.model
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
model: str = "tiny",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
audio: Union[str, np.ndarray, mx.array]
The path to the audio file to open, or the audio waveform
model: str
The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
Default is "tiny".
verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details;
        if False, displays minimal details. If None, does not display anything.
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
model = ModelHolder.get_model(model)
dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
content_frames = mel.shape[-2] - N_FRAMES
if verbose:
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
make_safe = lambda x: x.encode(system_encoding, errors="replace").decode(
system_encoding
)
else:
make_safe = lambda x: x
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. "
"Use the `language` decoding option to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES, axis=-2).astype(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
def decode_with_fallback(segment: mx.array) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
return decode_result
seek = 0
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
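    # With the usual Whisper audio constants (N_FRAMES = 3000 mel frames per 30-second
    # window, HOP_LENGTH = 160, SAMPLE_RATE = 16000) and n_audio_ctx = 1500, this works
    # out to input_stride = 3000 // 1500 = 2 and time_precision = 2 * 160 / 16000 = 0.02 s.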
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
*, start: float, end: float, tokens: mx.array, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames:
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = np.array(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
timestamp_tokens = tokens >= tokenizer.timestamp_begin
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
consecutive = np.where(
np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:])
)[0]
consecutive += 1
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
                # timestamp_tokens is a NumPy boolean array; flatnonzero gives the timestamp indices
                timestamps = tokens[np.flatnonzero(timestamp_tokens)]
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)

214
whisper/whisper/whisper.py Normal file
View File

@ -0,0 +1,214 @@
import base64
import gzip
from dataclasses import dataclass
import math
from typing import Dict, Iterable, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = mx.exp(-log_timescale_increment * mx.arange(channels // 2))
scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
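# For example, sinusoids(n_ctx, n_state) yields an (n_ctx, n_state) table, used below as
# the audio encoder's fixed (non-learned) positional embedding.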
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
def __call__(
self,
x,
xa=None,
mask=None,
kv_cache=None,
):
q = self.query(x)
if xa is None:
k = self.key(x)
v = self.value(x)
if kv_cache is not None:
k = mx.concatenate([kv_cache[0], k], axis=1)
v = mx.concatenate([kv_cache[1], v], axis=1)
elif kv_cache is None:
k = self.key(xa)
v = self.value(xa)
else:
k, v = kv_cache
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv), (k, v)
def qkv_attention(self, q, k, v, mask=None):
n_batch, n_ctx, n_state = q.shape
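        # Scaling both q and k by d_head ** -0.25 is equivalent to the standard
        # softmax(q @ k^T / sqrt(d_head)) attention; the factor is just split across
        # the two operands before the matmul.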
scale = (n_state // self.n_head) ** -0.25
q = q.reshape(*q.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3) * scale
k = k.reshape(*k.shape[:2], self.n_head, -1).transpose(0, 2, 3, 1) * scale
v = v.reshape(*v.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.astype(mx.float32)
w = mx.softmax(qk, axis=-1).astype(q.dtype)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp1 = nn.Linear(n_state, n_mlp)
self.mlp2 = nn.Linear(n_mlp, n_state)
self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
x += y
if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
return x, (kv, cross_kv)
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self._positional_embedding = sinusoids(n_ctx, n_state)
self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
self.ln_post = nn.LayerNorm(n_state)
def __call__(self, x):
x = nn.gelu(self.conv1(x))
x = nn.gelu(self.conv2(x))
assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
x = x + self._positional_embedding
for block in self.blocks:
x, _ = block(x)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = mx.zeros((n_ctx, n_state))
self.blocks = [
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
self.ln = nn.LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx)
def __call__(self, x, xa, kv_cache=None):
"""
x : mx.array, shape = (batch_size, <= n_ctx)
the text tokens
xa : mx.array, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
if kv_cache is None:
kv_cache = [None] * len(self.blocks)
for e, block in enumerate(self.blocks):
x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])
x = self.ln(x)
return x @ self.token_embedding.weight.T, kv_cache
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
def embed_audio(self, mel):
return self.encoder(mel)
def logits(self, tokens, audio_features):
return self.decoder(tokens, audio_features)[0]
def __call__(self, mel, tokens):
return self.decoder(tokens, self.encoder(mel))[0]
@property
def is_multilingual(self):
return self.dims.n_vocab == 51865
detect_language = detect_language_function
decode = decode_function
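# A minimal sketch of a raw forward pass. The dimensions below are assumed values for the
# "tiny" multilingual configuration, and the random mel is only a stand-in for a real
# log-mel spectrogram:
#
#   dims = ModelDimensions(
#       n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
#       n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
#   )
#   model = Whisper(dims)
#   mel = mx.random.normal((1, 3000, 80))  # (batch, frames, n_mels) for one 30-second window
#   tokens = mx.array([[0]])               # placeholder token id, just to exercise the decoder
#   logits = model(mel, tokens)            # shape: (1, 1, n_vocab)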