Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-01 12:06:37 +08:00)

Commit b243c1d8f4 ("a few examples"), parent e31d82d3ed.

.gitignore (new file, vendored, 129 lines)
@@ -0,0 +1,129 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
.pre-commit-config.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
repos:
  - repo: https://github.com/psf/black
    rev: 22.10.0
    hooks:
      - id: black
README.md (16 lines changed)
@@ -1,2 +1,14 @@
# mlx-examples
Examples in the MLX framework
# Transformer LM

This is an example of a decoder-only Transformer LM. The only dependency is
MLX.

Run the example on the GPU with:

```
python main.py --gpu
```

By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.

To run the PyTorch, Jax, or TensorFlow examples, install the respective framework.
mnist/README.md (new file, 18 lines)
@@ -0,0 +1,18 @@
# MNIST

This example shows how to run some simple models on MNIST. The only
dependency is MLX.

Run the example with:

```
python main.py
```

By default the example runs on the CPU. To run on the GPU, use:

```
python main.py --gpu
```

To run the PyTorch or Jax examples, install the respective framework.
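For instance, with PyTorch installed, the PyTorch version in this directory runs with:

```
python torch_main.py --gpu
```

and the Jax version with `python jax_main.py`.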
mnist/jax_main.py (new file, 80 lines)
@@ -0,0 +1,80 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import functools
|
||||
import time
|
||||
|
||||
import mnist
|
||||
|
||||
|
||||
def init_model(key, num_layers, input_dim, hidden_dim, output_dim):
|
||||
params = []
|
||||
layer_sizes = [hidden_dim] * num_layers
|
||||
for idim, odim in zip([input_dim] + layer_sizes, layer_sizes + [output_dim]):
|
||||
key, wk = jax.random.split(key, 2)
|
||||
W = 1e-2 * jax.random.normal(wk, (idim, odim))
|
||||
b = jnp.zeros((odim,))
|
||||
params.append((W, b))
|
||||
return params
|
||||
|
||||
|
||||
def feed_forward(params, X):
|
||||
for W, b in params[:-1]:
|
||||
X = jnp.maximum(X @ W + b, 0)
|
||||
W, b = params[-1]
|
||||
return X @ W + b
|
||||
|
||||
|
||||
def loss_fn(params, X, y):
|
||||
logits = feed_forward(params, X)
|
||||
logits = jax.nn.log_softmax(logits, 1)
|
||||
return -jnp.mean(logits[jnp.arange(y.size), y])
|
||||
|
||||
|
||||
@jax.jit
|
||||
def eval_fn(params, X, y):
|
||||
logits = feed_forward(params, X)
|
||||
return jnp.mean(jnp.argmax(logits, axis=1) == y)
|
||||
|
||||
|
||||
def batch_iterate(key, batch_size, X, y):
|
||||
perm = jax.random.permutation(key, y.size)
|
||||
for s in range(0, y.size, batch_size):
|
||||
ids = perm[s : s + batch_size]
|
||||
yield X[ids], y[ids]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed = 0
|
||||
num_layers = 2
|
||||
hidden_dim = 32
|
||||
num_classes = 10
|
||||
batch_size = 256
|
||||
num_epochs = 10
|
||||
learning_rate = 1e-1
|
||||
|
||||
# Load the data
|
||||
train_images, train_labels, test_images, test_labels = mnist.mnist()
|
||||
|
||||
# Load the model
|
||||
key, subkey = jax.random.split(jax.random.PRNGKey(seed))
|
||||
params = init_model(
|
||||
subkey, num_layers, train_images.shape[-1], hidden_dim, num_classes
|
||||
)
|
||||
|
||||
loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
|
||||
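# SGD update applied leaf-wise over the parameter pytree: p <- p - learning_rate * g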
update_fn = jax.jit(
|
||||
functools.partial(jax.tree_map, lambda p, g: p - learning_rate * g)
|
||||
)
|
||||
|
||||
for e in range(num_epochs):
|
||||
tic = time.perf_counter()
|
||||
key, subkey = jax.random.split(key)
|
||||
for X, y in batch_iterate(subkey, batch_size, train_images, train_labels):
|
||||
loss, grads = loss_and_grad_fn(params, X, y)
|
||||
params = update_fn(params, grads)
|
||||
accuracy = eval_fn(params, test_images, test_labels)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
|
||||
f" Time {toc - tic:.3f} (s)"
|
||||
)
|
mnist/main.py (new file, 88 lines)
@@ -0,0 +1,88 @@
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
import mnist
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""A simple MLP."""
|
||||
|
||||
def __init__(
|
||||
self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
|
||||
self.layers = [
|
||||
nn.Linear(idim, odim)
|
||||
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = mx.maximum(l(x), 0.0)
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
def loss_fn(model, X, y):
|
||||
return mx.mean(nn.losses.cross_entropy(model(X), y))
|
||||
|
||||
|
||||
def eval_fn(model, X, y):
|
||||
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
||||
|
||||
|
||||
def batch_iterate(batch_size, X, y):
|
||||
perm = mx.array(np.random.permutation(y.size))
|
||||
for s in range(0, y.size, batch_size):
|
||||
ids = perm[s : s + batch_size]
|
||||
yield X[ids], y[ids]
|
||||
|
||||
|
||||
def main():
|
||||
seed = 0
|
||||
num_layers = 2
|
||||
hidden_dim = 32
|
||||
num_classes = 10
|
||||
batch_size = 256
|
||||
num_epochs = 10
|
||||
learning_rate = 1e-1
|
||||
|
||||
np.random.seed(seed)
|
||||
|
||||
# Load the data
|
||||
train_images, train_labels, test_images, test_labels = map(mx.array, mnist.mnist())
|
||||
|
||||
# Load the model
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
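# nn.value_and_grad returns a function computing both the loss and its gradients
# with respect to the model's trainable parameters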
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
optimizer = optim.SGD(learning_rate=learning_rate)
|
||||
|
||||
for e in range(num_epochs):
|
||||
tic = time.perf_counter()
|
||||
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
||||
loss, grads = loss_and_grad_fn(model, X, y)
|
||||
optimizer.update(model, grads)
|
||||
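# MLX computations are lazy; force evaluation of the updated parameters and optimizer state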
mx.eval(model.parameters(), optimizer.state)
|
||||
accuracy = eval_fn(model, test_images, test_labels)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
|
||||
f" Time {toc - tic:.3f} (s)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with MLX.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if not args.gpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
main()
|
mnist/mnist.py (new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
import gzip
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
from urllib import request
|
||||
|
||||
|
||||
def mnist(save_dir="/tmp"):
|
||||
"""
|
||||
Load the MNIST dataset in 4 tensors: train images, train labels,
|
||||
test images, and test labels.
|
||||
|
||||
Checks `save_dir` for already downloaded data otherwise downloads.
|
||||
|
||||
Download code modified from:
|
||||
https://github.com/hsjeong5/MNIST-for-Numpy
|
||||
"""
|
||||
|
||||
def download_and_save(save_file):
|
||||
base_url = "http://yann.lecun.com/exdb/mnist/"
|
||||
filename = [
|
||||
["training_images", "train-images-idx3-ubyte.gz"],
|
||||
["test_images", "t10k-images-idx3-ubyte.gz"],
|
||||
["training_labels", "train-labels-idx1-ubyte.gz"],
|
||||
["test_labels", "t10k-labels-idx1-ubyte.gz"],
|
||||
]
|
||||
|
||||
mnist = {}
|
||||
for name in filename:
|
||||
out_file = os.path.join("/tmp", name[1])
|
||||
request.urlretrieve(base_url + name[1], out_file)
|
||||
for name in filename[:2]:
|
||||
out_file = os.path.join("/tmp", name[1])
|
||||
with gzip.open(out_file, "rb") as f:
|
||||
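# offset=16 skips the IDX image-file header; each row becomes one flattened 28x28 image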
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
|
||||
-1, 28 * 28
|
||||
)
|
||||
for name in filename[-2:]:
|
||||
out_file = os.path.join("/tmp", name[1])
|
||||
with gzip.open(out_file, "rb") as f:
|
||||
mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
with open(save_file, "wb") as f:
|
||||
pickle.dump(mnist, f)
|
||||
|
||||
save_file = os.path.join(save_dir, "mnist.pkl")
|
||||
if not os.path.exists(save_file):
|
||||
download_and_save(save_file)
|
||||
with open(save_file, "rb") as f:
|
||||
mnist = pickle.load(f)
|
||||
|
||||
preproc = lambda x: x.astype(np.float32) / 255.0
|
||||
mnist["training_images"] = preproc(mnist["training_images"])
|
||||
mnist["test_images"] = preproc(mnist["test_images"])
|
||||
return (
|
||||
mnist["training_images"],
|
||||
mnist["training_labels"].astype(np.uint32),
|
||||
mnist["test_images"],
|
||||
mnist["test_labels"].astype(np.uint32),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_x, train_y, test_x, test_y = mnist()
|
||||
assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
|
||||
assert train_y.shape == (60000,), "Wrong training set size"
|
||||
assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
|
||||
assert test_y.shape == (10000,), "Wrong test set size"
|
mnist/torch_main.py (new file, 88 lines)
@@ -0,0 +1,88 @@
|
||||
import argparse
|
||||
import torch
|
||||
import time
|
||||
|
||||
import mnist
|
||||
|
||||
|
||||
class MLP(torch.nn.Module):
|
||||
def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
layer_sizes = [hidden_dim] * num_layers
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(idim, odim)
|
||||
for idim, odim in zip(
|
||||
[input_dim] + layer_sizes, layer_sizes + [output_dim]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layers[0](x)
|
||||
for l in self.layers[1:]:
|
||||
x = l(x.relu())
|
||||
return x
|
||||
|
||||
|
||||
def loss_fn(model, X, y):
|
||||
logits = model(X)
|
||||
return torch.nn.functional.cross_entropy(logits, y)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_fn(model, X, y):
|
||||
logits = model(X)
|
||||
return torch.mean((logits.argmax(-1) == y).float())
|
||||
|
||||
|
||||
def batch_iterate(batch_size, X, y, device):
|
||||
perm = torch.randperm(len(y), device=device)
|
||||
for s in range(0, len(y), batch_size):
|
||||
ids = perm[s : s + batch_size]
|
||||
yield X[ids], y[ids]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Train a simple MLP on MNIST with PyTorch.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.gpu:
|
||||
torch.set_num_threads(1)
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "mps"
|
||||
seed = 0
|
||||
num_layers = 2
|
||||
hidden_dim = 32
|
||||
num_classes = 10
|
||||
batch_size = 256
|
||||
num_epochs = 10
|
||||
learning_rate = 1e-1
|
||||
|
||||
# Load the data
|
||||
def to_tensor(x):
|
||||
if x.dtype != "uint32":
|
||||
return torch.from_numpy(x).to(device)
|
||||
else:
|
||||
return torch.from_numpy(x.astype(int)).to(device)
|
||||
|
||||
train_images, train_labels, test_images, test_labels = map(to_tensor, mnist.mnist())
|
||||
|
||||
# Load the model
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes).to(device)
|
||||
opt = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)
|
||||
|
||||
for e in range(num_epochs):
|
||||
tic = time.perf_counter()
|
||||
for X, y in batch_iterate(batch_size, train_images, train_labels, device):
|
||||
opt.zero_grad()
|
||||
loss_fn(model, X, y).backward()
|
||||
opt.step()
|
||||
accuracy = eval_fn(model, test_images, test_labels)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
|
||||
f" Time {toc - tic:.3f} (s)"
|
||||
)
|
transformer_lm/README.md (new file, 14 lines)
@@ -0,0 +1,14 @@
# Transformer LM

This is an example of a decoder-only Transformer LM. The only dependency is
MLX.

Run the example on the GPU with:

```
python main.py --gpu
```

By default the dataset is the [PTB corpus](https://paperswithcode.com/dataset/penn-treebank). Choose a different dataset with the `--dataset` option.

To run the PyTorch, Jax, or TensorFlow examples, install the respective framework.
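For example, to train on WikiText-2 instead of the default PTB (one of the `--dataset` choices defined in `main.py` in this directory):

```
python main.py --gpu --dataset wikitext2
```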
transformer_lm/datasets.py (new file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
import io
|
||||
import itertools
|
||||
import numpy as np
|
||||
import os
|
||||
from urllib import request
|
||||
import zipfile
|
||||
|
||||
|
||||
def load_dataset(dataname):
|
||||
if dataname == "ptb":
|
||||
return ptb()
|
||||
elif dataname == "wikitext2":
|
||||
return wikitext(dataset="2")
|
||||
else:
|
||||
return wikitext(dataset="103")
|
||||
|
||||
|
||||
def _load(save_dir, filenames):
|
||||
# *NB* First file is expected to be the training set
|
||||
with open(os.path.join(save_dir, filenames[0]), "r") as fid:
|
||||
vocab = set(t for l in fid.readlines() for t in l.strip().split(" "))
|
||||
eos = "<eos>"
|
||||
vocab.add(eos)
|
||||
vocab = {v: i for i, v in enumerate(vocab)}
|
||||
|
||||
def to_array(dataset):
|
||||
with open(os.path.join(save_dir, dataset), "r") as fid:
|
||||
lines = (l.strip().split(" ") for l in fid.readlines())
|
||||
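# Map each token to its integer id, appending <eos> at the end of every line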
return np.array(
|
||||
[vocab[w] for line in lines for w in itertools.chain(line, [eos])],
|
||||
dtype=np.uint32,
|
||||
)
|
||||
|
||||
datasets = [to_array(fn) for fn in filenames]
|
||||
return vocab, *datasets
|
||||
|
||||
|
||||
def wikitext(dataset="2", save_dir="/tmp"):
|
||||
"""
|
||||
Load the WikiText-2 or WikiText-103 language modeling dataset:
https://paperswithcode.com/dataset/wikitext-2
|
||||
"""
|
||||
if dataset not in ("2", "103"):
|
||||
raise ValueError(f'Dataset must be either "2" or "103", got {dataset}')
|
||||
|
||||
filenames = ["wiki.train.tokens", "wiki.valid.tokens", "wiki.test.tokens"]
|
||||
dataname = f"wikitext-{dataset}"
|
||||
data_dir = os.path.join(save_dir, dataname)
|
||||
if not os.path.exists(data_dir):
|
||||
base_url = "https://s3.amazonaws.com/research.metamind.io/wikitext/"
|
||||
zip_file_url = base_url + dataname + "-v1.zip"
|
||||
r = request.urlopen(zip_file_url)
|
||||
with zipfile.ZipFile(io.BytesIO(r.read())) as zf:
|
||||
zf.extractall(save_dir)
|
||||
|
||||
return _load(data_dir, filenames)
|
||||
|
||||
|
||||
def ptb(save_dir="/tmp"):
|
||||
"""
|
||||
Load the PTB language modeling dataset:
|
||||
https://paperswithcode.com/dataset/penn-treebank
|
||||
"""
|
||||
filenames = [
|
||||
"ptb.train.txt",
|
||||
"ptb.valid.txt",
|
||||
"ptb.test.txt",
|
||||
]
|
||||
|
||||
def download_and_save(save_dir):
|
||||
base_url = "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/"
|
||||
for name in filenames:
|
||||
out_file = os.path.join(save_dir, name)
|
||||
if not os.path.exists(out_file):
|
||||
request.urlretrieve(base_url + name, out_file)
|
||||
|
||||
save_dir = os.path.join(save_dir, "ptb")
|
||||
if not os.path.exists(save_dir):
|
||||
os.mkdir(save_dir)
|
||||
download_and_save(save_dir)
|
||||
|
||||
return _load(save_dir, filenames)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vocab, train, val, test = ptb()
|
||||
assert len(vocab) == 10000, "PTB: Wrong vocab size"
|
||||
|
||||
vocab, train, val, test = wikitext()
|
||||
assert len(vocab) == 33279, "WikiText: Wrong vocab size"
|
transformer_lm/jax_main.py (new file, 303 lines)
@@ -0,0 +1,303 @@
|
||||
import functools
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import math
|
||||
import numpy as np
|
||||
import time
|
||||
from collections import namedtuple
|
||||
|
||||
import datasets
|
||||
from tree_utils import tree_flatten
|
||||
|
||||
"""
|
||||
Some TODOs for this model:
|
||||
- Positional encodings
|
||||
- Dropout
|
||||
- Adam optimizer
|
||||
- Option for bigger datasets (wikitext / librispeech text < c4 < ...)
|
||||
"""
|
||||
|
||||
RuntimeConfig = namedtuple("RuntimeConfig", "num_heads")
|
||||
|
||||
|
||||
def embedding_init(key, num_embeddings, embed_dim):
|
||||
return jax.random.uniform(
|
||||
key, (num_embeddings, embed_dim), minval=-1e-1, maxval=1e-1
|
||||
)
|
||||
|
||||
|
||||
def embedding_apply(params, X):
|
||||
return params.take(X, axis=0)
|
||||
|
||||
|
||||
def dense_init(key, in_dim, out_dim, bias=True):
|
||||
k1, k2 = jax.random.split(key)
|
||||
scale = math.sqrt(1 / in_dim)
|
||||
params = [jax.random.uniform(k1, (in_dim, out_dim), minval=-scale, maxval=scale)]
|
||||
if bias:
|
||||
params.append(jax.random.uniform(k2, (out_dim,), minval=-scale, maxval=scale))
|
||||
return params
|
||||
|
||||
|
||||
def dense_apply(params, X):
|
||||
X = X @ params[0]
|
||||
if len(params) == 2:
|
||||
X = X + params[1]
|
||||
return X
|
||||
|
||||
|
||||
def layernorm_init(key, dim):
|
||||
return [jnp.zeros((dim,)), jnp.ones((dim,))]
|
||||
|
||||
|
||||
def layernorm_apply(params, X, epsilon=1e-6):
|
||||
means = jnp.mean(X, axis=-1, keepdims=True)
|
||||
var = jnp.var(X, axis=-1, keepdims=True)
|
||||
X = (X - means) / jnp.sqrt(var + epsilon)
|
||||
beta, gamma = params
|
||||
return gamma * X + beta
|
||||
|
||||
|
||||
def mlpblock_init(key, dim):
|
||||
k1, k2 = jax.random.split(key)
|
||||
return {
|
||||
"dense1": dense_init(k1, dim, 4 * dim),
|
||||
"dense2": dense_init(k2, 4 * dim, dim),
|
||||
}
|
||||
|
||||
|
||||
def mlpblock_apply(params, X):
|
||||
X = dense_apply(params["dense1"], X)
|
||||
X = jnp.maximum(X, 0)
|
||||
# TODO dropout option here
|
||||
return dense_apply(params["dense2"], X)
|
||||
|
||||
|
||||
def selfattention_init(key, dim):
|
||||
k1, k2, k3, k4 = jax.random.split(key, 4)
|
||||
return {
|
||||
"Q": dense_init(k1, dim, dim, bias=False),
|
||||
"K": dense_init(k2, dim, dim, bias=False),
|
||||
"V": dense_init(k3, dim, dim, bias=False),
|
||||
"out": dense_init(k4, dim, dim, bias=False),
|
||||
}
|
||||
|
||||
|
||||
def selfattention_apply(params, num_heads, X, mask):
|
||||
queries = dense_apply(params["Q"], X)
|
||||
keys = dense_apply(params["K"], X)
|
||||
values = dense_apply(params["V"], X)
|
||||
|
||||
B, L, D = queries.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ jnp.transpose(keys, (0, 1, 3, 2))
|
||||
scores = jax.nn.softmax(scores + mask, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return dense_apply(params["out"], values_hat)
|
||||
|
||||
|
||||
def transformer_init(key, token_set_size, num_blocks, dim):
|
||||
key, ek = jax.random.split(key)
|
||||
params = {"embedding": embedding_init(ek, token_set_size, dim)}
|
||||
transformer_blocks = []
|
||||
for b in range(num_blocks):
|
||||
key, k1, k2, k3, k4 = jax.random.split(key, 5)
|
||||
transformer_blocks.append(
|
||||
{
|
||||
"ln1": layernorm_init(k1, dim),
|
||||
"ln2": layernorm_init(k2, dim),
|
||||
"selfattention": selfattention_init(k3, dim),
|
||||
"mlpblock": mlpblock_init(k4, dim),
|
||||
}
|
||||
)
|
||||
params["transformer_blocks"] = transformer_blocks
|
||||
params["output"] = dense_init(key, dim, token_set_size)
|
||||
return params
|
||||
|
||||
|
||||
def create_additive_causal_mask(N):
|
||||
indices = jnp.arange(N)
|
||||
mask = jnp.reshape(indices, (-1, 1)) < jnp.reshape(indices, (1, -1))
|
||||
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
|
||||
mask = mask.astype(jnp.float32) * -1e9
|
||||
return mask
|
||||
|
||||
|
||||
def transformer_apply(params, static_params, inputs):
|
||||
mask = create_additive_causal_mask(inputs.shape[1])
|
||||
X = embedding_apply(params["embedding"], inputs)
|
||||
for block in params["transformer_blocks"]:
|
||||
out = layernorm_apply(block["ln1"], X)
|
||||
out = selfattention_apply(
|
||||
block["selfattention"], static_params.num_heads, out, mask
|
||||
)
|
||||
X = X + out
|
||||
out = layernorm_apply(block["ln2"], X)
|
||||
out = mlpblock_apply(block["mlpblock"], out)
|
||||
X = X + out
|
||||
return dense_apply(params["output"], X)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["static_params", "reduce"])
|
||||
def loss_fn(params, static_params, inputs, targets, reduce=True):
|
||||
logits = transformer_apply(params, static_params, inputs)
|
||||
logits = jax.nn.log_softmax(logits, axis=-1)
|
||||
sample_indices = jnp.arange(targets.shape[0])[:, None]
|
||||
token_indices = jnp.arange(targets.shape[1])[None, :]
|
||||
losses = -logits[sample_indices, token_indices, targets]
|
||||
return jnp.mean(losses) if reduce else losses.mean(-1)
|
||||
|
||||
|
||||
def to_samples(context_size, dataset):
|
||||
tokens = dataset.size
|
||||
window_size = context_size + 1 # include target
|
||||
samples = tokens - window_size + 1
|
||||
X = np.lib.stride_tricks.as_strided(
|
||||
dataset,
|
||||
shape=(samples, window_size),
|
||||
strides=(dataset.itemsize, dataset.itemsize),
|
||||
)
|
||||
return X[:, :-1], X[:, 1:]
|
||||
|
||||
|
||||
def iterate_batches(key, batch_size, context_size, dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
s = 0
|
||||
while True:
|
||||
if s == 0:
|
||||
# Reset permutation:
|
||||
key, subkey = jax.random.split(key)
|
||||
perm = jax.random.permutation(subkey, inputs.shape[0])
|
||||
ids = perm[s : s + batch_size]
|
||||
yield inputs[ids], targets[ids]
|
||||
s += batch_size
|
||||
if s >= inputs.shape[0]:
|
||||
s = 0
|
||||
|
||||
|
||||
def main(args):
|
||||
batch_size = args.batch_size
|
||||
context_size = args.context_size
|
||||
steps_per_eval = args.steps_per_eval
|
||||
steps_per_report = args.steps_per_report
|
||||
config = RuntimeConfig(args.num_heads)
|
||||
|
||||
# Load vocab and dataset:
|
||||
vocab, train, valid, test = datasets.ptb()
|
||||
|
||||
# Initialize model:
|
||||
key, subkey = jax.random.split(jax.random.PRNGKey(args.seed))
|
||||
params = transformer_init(subkey, len(vocab), args.num_blocks, args.dim)
|
||||
nparams = sum(x.size for k, x in tree_flatten(params) if "embedding" not in k)
|
||||
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
|
||||
|
||||
loss_and_grad_fn = jax.jit(
|
||||
jax.value_and_grad(loss_fn), static_argnames=["static_params"]
|
||||
)
|
||||
update_fn = jax.jit(
|
||||
functools.partial(jax.tree_map, lambda p, g: p - args.learning_rate * g)
|
||||
)
|
||||
|
||||
def eval_fn(params, dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
loss = 0
|
||||
for s in range(0, targets.shape[0], batch_size):
|
||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
||||
losses = loss_fn(params, config, bx, by, reduce=False)
|
||||
loss += jnp.sum(losses)
|
||||
return loss / len(targets)
|
||||
|
||||
train_iterator = iterate_batches(subkey, batch_size, context_size, train)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
|
||||
loss, grads = loss_and_grad_fn(params, config, inputs, targets)
|
||||
losses.append(loss.item())
|
||||
params = update_fn(params, grads)
|
||||
if (it + 1) % steps_per_report == 0:
|
||||
train_loss = np.mean(losses)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {steps_per_report / (toc - tic):.3f}"
|
||||
)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
|
||||
if (it + 1) % steps_per_eval == 0:
|
||||
val_loss = eval_fn(params, valid)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val ppl {math.exp(val_loss):.3f}, "
|
||||
f"Val took {(toc - tic):.3f}s, "
|
||||
)
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.eval_test:
|
||||
test_loss = eval_fn(params, test)
|
||||
test_ppl = math.exp(test_loss)
|
||||
print(f"Test loss {test_loss.item():.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with Jax.")
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=0, help="Seed for numpy and Jax RNGs."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Context size in tokens of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Dimensionality of embeddings and hidden layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_heads",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of heads used for multi-head attention",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=100000, help="Iterations to train for."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_report",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_eval",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
transformer_lm/main.py (new file, 190 lines)
@@ -0,0 +1,190 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
class TransformerLM(nn.Module):
|
||||
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, dims)
|
||||
self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
|
||||
self.out_proj = nn.Linear(dims, vocab_size)
|
||||
|
||||
def __call__(self, x):
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
x = self.embedding(x)
|
||||
x = self.transformer(x, mask)
|
||||
return self.out_proj(x)
|
||||
|
||||
def loss(self, x, y, reduce=True):
|
||||
logits = self(x)
|
||||
losses = nn.losses.cross_entropy(logits, y)
|
||||
mx.simplify(losses)
|
||||
|
||||
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
|
||||
|
||||
|
||||
def to_samples(context_size, dataset):
|
||||
tokens = dataset.size
|
||||
window_size = context_size + 1 # include target
|
||||
samples = tokens - window_size + 1
|
||||
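# View the 1-D token stream as overlapping windows of window_size consecutive tokens
# (one-token stride between rows); inputs are X[:, :-1] and targets X[:, 1:], shifted by one token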
X = np.lib.stride_tricks.as_strided(
|
||||
dataset,
|
||||
shape=(samples, window_size),
|
||||
strides=(dataset.itemsize, dataset.itemsize),
|
||||
)
|
||||
return X[:, :-1], X[:, 1:]
|
||||
|
||||
|
||||
def iterate_batches(batch_size, context_size, dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
s = 0
|
||||
while True:
|
||||
if s == 0:
|
||||
# Reset permutation:
|
||||
perm = np.random.permutation(inputs.shape[0])
|
||||
ids = perm[s : s + batch_size]
|
||||
yield inputs[ids], targets[ids]
|
||||
s += batch_size
|
||||
if s >= inputs.shape[0]:
|
||||
s = 0
|
||||
|
||||
|
||||
def main(args):
|
||||
batch_size = args.batch_size
|
||||
context_size = args.context_size
|
||||
steps_per_eval = args.steps_per_eval
|
||||
steps_per_report = args.steps_per_report
|
||||
|
||||
# Load vocab and dataset:
|
||||
vocab, train, valid, test = datasets.load_dataset(args.dataset)
|
||||
|
||||
# Initialize model:
|
||||
model = TransformerLM(len(vocab), args.num_blocks, args.dim, args.num_heads)
|
||||
mx.eval(model.parameters())
|
||||
nparams = sum(
|
||||
x.size for k, x in tree_flatten(model.parameters()) if "embedding" not in k
|
||||
)
|
||||
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
|
||||
|
||||
optimizer = optim.SGD(learning_rate=args.learning_rate)
|
||||
loss_and_grad_fn = nn.value_and_grad(model, model.loss)
|
||||
|
||||
def eval_fn(model, dataset):
|
||||
inputs, targets = map(mx.array, to_samples(context_size, dataset))
|
||||
loss = 0
|
||||
for s in range(0, targets.shape[0], batch_size):
|
||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
||||
bx, by = map(mx.array, (bx, by))
|
||||
losses = model.loss(bx, by, reduce=False)
|
||||
loss += mx.sum(losses).item()
|
||||
return loss / len(targets)
|
||||
|
||||
train_iterator = iterate_batches(batch_size, context_size, train)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
|
||||
inputs, targets = map(mx.array, (inputs, targets))
|
||||
loss, grads = loss_and_grad_fn(inputs, targets)
|
||||
model.update(optimizer.apply_gradients(grads, model))
|
||||
mx.simplify(loss, model.parameters())
|
||||
mx.eval(loss, model.parameters())
|
||||
losses.append(loss.item())
|
||||
if (it + 1) % steps_per_report == 0:
|
||||
train_loss = np.mean(losses)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {steps_per_report / (toc - tic):.3f}"
|
||||
)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
|
||||
if (it + 1) % steps_per_eval == 0:
|
||||
val_loss = eval_fn(model, valid)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val ppl {math.exp(val_loss):.3f}, "
|
||||
f"Val took {(toc - tic):.3f}s, "
|
||||
)
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.eval_test:
|
||||
test_loss = eval_fn(model, test)
|
||||
test_ppl = math.exp(test_loss)
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="ptb",
|
||||
choices=["ptb", "wikitext2", "wikitext103"],
|
||||
help="Dataset to train and evaluate on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Context size in tokens of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Dimensionality of embeddings and hidden layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_heads",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of heads used for multi-head attention",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=100000, help="Iterations to train for."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_report",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_eval",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if not args.gpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
main(args)
|
transformer_lm/tf_main.py (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
def to_samples(context_size, dataset):
|
||||
tokens = dataset.size
|
||||
window_size = context_size + 1 # include target
|
||||
samples = tokens - window_size + 1
|
||||
X = np.lib.stride_tricks.as_strided(
|
||||
dataset,
|
||||
shape=(samples, window_size),
|
||||
strides=(dataset.itemsize, dataset.itemsize),
|
||||
)
|
||||
return X[:, :-1], X[:, 1:]
|
||||
|
||||
|
||||
def iterate_batches(batch_size, context_size, dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
s = 0
|
||||
while True:
|
||||
if s == 0:
|
||||
# Reset permutation:
|
||||
perm = np.random.permutation(inputs.shape[0])
|
||||
ids = perm[s : s + batch_size]
|
||||
yield inputs[ids], targets[ids]
|
||||
s += batch_size
|
||||
if s + batch_size >= inputs.shape[0]:
|
||||
s = 0
|
||||
|
||||
|
||||
def create_additive_causal_mask(N):
|
||||
indices = tf.range(N)
|
||||
mask = tf.reshape(indices, (-1, 1)) < tf.reshape(indices, (1, -1))
|
||||
return tf.cast(mask, tf.dtypes.float32) * -1e9
|
||||
|
||||
|
||||
class SelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, num_heads, model_dims, context_size):
|
||||
super().__init__()
|
||||
self.Wq = tf.keras.layers.Dense(model_dims, use_bias=False)
|
||||
self.Wk = tf.keras.layers.Dense(model_dims, use_bias=False)
|
||||
self.Wv = tf.keras.layers.Dense(model_dims, use_bias=False)
|
||||
self.Wo = tf.keras.layers.Dense(model_dims, use_bias=False)
|
||||
self.causal_mask = create_additive_causal_mask(context_size)
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = model_dims // num_heads
|
||||
self.scale = tf.constant(1.0 / math.sqrt(self.head_dim))
|
||||
|
||||
def call(self, x):
|
||||
queries = self.Wq(x)
|
||||
keys = self.Wk(x)
|
||||
values = self.Wv(x)
|
||||
|
||||
B, L, D = x.shape
|
||||
queries = tf.transpose(
|
||||
tf.reshape(queries, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
|
||||
)
|
||||
keys = tf.transpose(
|
||||
tf.reshape(keys, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
|
||||
)
|
||||
values = tf.transpose(
|
||||
tf.reshape(values, (B, L, self.num_heads, -1)), perm=(0, 2, 1, 3)
|
||||
)
|
||||
|
||||
scores = (self.scale * queries) @ tf.transpose(keys, (0, 1, 3, 2))
|
||||
scores = tf.nn.softmax(scores + self.causal_mask, axis=-1)
|
||||
values = tf.matmul(scores, values)
|
||||
values_hat = tf.reshape(tf.transpose(values, perm=(0, 2, 1, 3)), (B, L, -1))
|
||||
|
||||
return self.Wo(values_hat)
|
||||
|
||||
|
||||
class EncoderLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, num_heads, model_dims, context_size):
|
||||
super().__init__()
|
||||
self._ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
|
||||
|
||||
self._self_attn = SelfAttention(num_heads, model_dims, context_size)
|
||||
|
||||
self._ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
|
||||
|
||||
self._mlp = tf.keras.Sequential(
|
||||
[
|
||||
tf.keras.layers.Dense(4 * model_dims, activation="relu"),
|
||||
tf.keras.layers.Dense(model_dims),
|
||||
]
|
||||
)
|
||||
|
||||
def call(self, x):
|
||||
x = x + self._self_attn(self._ln1(x))
|
||||
x = x + self._mlp(self._ln2(x))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerLM(tf.keras.Model):
|
||||
def __init__(self, vocab_size, num_layers, num_heads, model_dims, context_size):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = tf.keras.layers.Embedding(vocab_size, model_dims)
|
||||
self.transformer = tf.keras.Sequential(
|
||||
[
|
||||
EncoderLayer(num_heads, model_dims, context_size)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.projection = tf.keras.layers.Dense(vocab_size)
|
||||
|
||||
def call(self, x):
|
||||
x = self.embedding(x)
|
||||
x = self.transformer(x)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
def main(args, device):
|
||||
with tf.device(device):
|
||||
batch_size = args.batch_size
|
||||
context_size = args.context_size
|
||||
steps_per_eval = args.steps_per_eval
|
||||
steps_per_report = args.steps_per_report
|
||||
|
||||
# Load vocab and dataset:
|
||||
vocab, train, valid, test = datasets.ptb()
|
||||
|
||||
# Initialize model:
|
||||
transformer = TransformerLM(
|
||||
len(vocab), args.num_blocks, args.num_heads, args.dim, context_size
|
||||
)
|
||||
transformer.compile(
|
||||
optimizer=tf.keras.optimizers.legacy.SGD(learning_rate=args.learning_rate),
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
||||
)
|
||||
transformer.build((batch_size, context_size))
|
||||
nparams = sum(
|
||||
np.prod(p.shape) for p in transformer.trainable_weights[1:]
|
||||
) # [1:] to skip the embedding
|
||||
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
|
||||
|
||||
def eval_fn(dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
loss = 0
|
||||
n_batches = 0
|
||||
for s in range(0, targets.shape[0], batch_size):
|
||||
if s + batch_size >= targets.shape[0]:
|
||||
s = targets.shape[0] - 1 - batch_size
|
||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
||||
bx, by = map(
|
||||
lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
|
||||
[bx, by],
|
||||
)
|
||||
loss += transformer.test_on_batch(bx, by)
|
||||
n_batches += 1
|
||||
return loss / n_batches
|
||||
|
||||
train_iterator = iterate_batches(batch_size, context_size, train)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
|
||||
inputs, targets = map(
|
||||
lambda x: tf.convert_to_tensor(x, dtype=tf.dtypes.int32),
|
||||
[inputs, targets],
|
||||
)
|
||||
loss = transformer.train_on_batch(inputs, targets)
|
||||
losses.append(loss)
|
||||
if (it + 1) % steps_per_report == 0:
|
||||
train_loss = np.mean(losses)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {steps_per_report / (toc - tic):.3f}"
|
||||
)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
|
||||
if (it + 1) % steps_per_eval == 0:
|
||||
val_loss = eval_fn(valid)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val ppl {math.exp(val_loss):.3f}, "
|
||||
f"Val took {(toc - tic):.3f}s, "
|
||||
)
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.eval_test:
|
||||
test_loss = eval_fn(test)
|
||||
test_ppl = math.exp(test_loss)
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
|
||||
parser.add_argument(
|
||||
"--context_size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Context size in tokens of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Dimensionality of embeddings and hidden layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_heads",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of heads used for multi-head attention",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=100000, help="Iterations to train for."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_report",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_eval",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args, device="/GPU:0" if args.gpu else "/CPU:0")
|
transformer_lm/torch_main.py (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
def to_samples(context_size, dataset):
|
||||
tokens = dataset.size
|
||||
window_size = context_size + 1 # include target
|
||||
samples = tokens - window_size + 1
|
||||
X = np.lib.stride_tricks.as_strided(
|
||||
dataset,
|
||||
shape=(samples, window_size),
|
||||
strides=(dataset.itemsize, dataset.itemsize),
|
||||
)
|
||||
return X[:, :-1], X[:, 1:]
|
||||
|
||||
|
||||
def iterate_batches(batch_size, context_size, dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
s = 0
|
||||
while True:
|
||||
if s == 0:
|
||||
# Reset permutation:
|
||||
perm = np.random.permutation(inputs.shape[0])
|
||||
ids = perm[s : s + batch_size]
|
||||
yield inputs[ids], targets[ids]
|
||||
s += batch_size
|
||||
if s >= inputs.shape[0]:
|
||||
s = 0
|
||||
|
||||
|
||||
def create_additive_causal_mask(N, device):
|
||||
# torch.nn.Transformer.generate_square_subsequent_mask
|
||||
# gives NaNs with `device="mps"`
|
||||
indices = torch.arange(N, device=device)
|
||||
mask = indices.reshape((-1, 1)) < indices.reshape((1, -1))
|
||||
return mask.to(torch.float32) * -1e9
|
||||
|
||||
|
||||
class TransformerLM(torch.nn.Module):
|
||||
def __init__(self, vocab_size, num_layers, num_heads, model_dims):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = torch.nn.Embedding(vocab_size, model_dims)
|
||||
self.transformer = torch.nn.TransformerEncoder(
|
||||
torch.nn.TransformerEncoderLayer(
|
||||
model_dims,
|
||||
num_heads,
|
||||
4 * model_dims,
|
||||
dropout=0.0,
|
||||
norm_first=True,
|
||||
batch_first=True,
|
||||
),
|
||||
num_layers,
|
||||
)
|
||||
self.projection = torch.nn.Linear(model_dims, vocab_size)
|
||||
|
||||
def forward(self, x):
|
||||
mask = create_additive_causal_mask(x.shape[1], device=x.device)
|
||||
x = self.embedding(x)
|
||||
x = self.transformer(x, mask=mask)
|
||||
x = self.projection(x)
|
||||
return x
|
||||
|
||||
|
||||
def main(args, device):
|
||||
batch_size = args.batch_size
|
||||
context_size = args.context_size
|
||||
steps_per_eval = args.steps_per_eval
|
||||
steps_per_report = args.steps_per_report
|
||||
|
||||
# Load vocab and dataset:
|
||||
vocab, train, valid, test = datasets.ptb()
|
||||
|
||||
# Initialize model:
|
||||
transformer = TransformerLM(len(vocab), args.num_blocks, args.num_heads, args.dim)
|
||||
transformer = transformer.to(device)
|
||||
optim = torch.optim.SGD(transformer.parameters(), lr=args.learning_rate, momentum=0)
|
||||
nparams = sum(
|
||||
p.numel() for n, p in transformer.named_parameters() if "embedding" not in n
|
||||
)
|
||||
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_fn(dataset):
|
||||
inputs, targets = to_samples(context_size, dataset)
|
||||
loss = 0
|
||||
for s in range(0, targets.shape[0], batch_size):
|
||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
||||
bx, by = map(lambda x: torch.from_numpy(x.astype(int)).to(device), [bx, by])
|
||||
logits = transformer(bx)
|
||||
losses = torch.nn.functional.cross_entropy(
|
||||
logits.flatten(0, 1), by.flatten(), reduction="none"
|
||||
)
|
||||
losses = losses.view(-1, by.shape[-1]).mean(-1)
|
||||
loss += losses.sum().item()
|
||||
return loss / len(targets)
|
||||
|
||||
train_iterator = iterate_batches(batch_size, context_size, train)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
|
||||
inputs, targets = map(
|
||||
lambda x: torch.from_numpy(x.astype(int)).to(device), [inputs, targets]
|
||||
)
|
||||
optim.zero_grad()
|
||||
logits = transformer(inputs)
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
logits.flatten(0, 1), targets.flatten()
|
||||
)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
losses.append(loss.item())
|
||||
if (it + 1) % steps_per_report == 0:
|
||||
train_loss = np.mean(losses)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: Train loss {train_loss:.3f}, "
|
||||
f"It/sec {steps_per_report / (toc - tic):.3f}"
|
||||
)
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
|
||||
if (it + 1) % steps_per_eval == 0:
|
||||
val_loss = eval_fn(valid)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val ppl {math.exp(val_loss):.3f}, "
|
||||
f"Val took {(toc - tic):.3f}s, "
|
||||
)
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.eval_test:
|
||||
test_loss = eval_fn(test)
|
||||
test_ppl = math.exp(test_loss)
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Seed for the RNGs.")
|
||||
parser.add_argument(
|
||||
"--context_size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Context size in tokens of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_blocks", type=int, default=12, help="Number of Transformer blocks."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Dimensionality of embeddings and hidden layers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_heads",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of heads used for multi-head attention",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=2, help="Minibatch size.")
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=100000, help="Iterations to train for."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate", type=float, default=1e-3, help="SGD learning rate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_report",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_eval",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args, device="mps" if args.gpu else "cpu")
|
whisper/README.md (new file, 27 lines)
@@ -0,0 +1,27 @@
# whisper

Whisper in MLX.

First install the dependencies:

(TODO, MLX install link / command / add to requirements.txt)

```
pip install -r requirements.txt
```

Install [`ffmpeg`](https://ffmpeg.org/):

```bash
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```

Then transcribe audio with:

```
import whisper

text = whisper.transcribe(speech_file)["text"]
```
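The returned dictionary also carries the detected language and per-segment metadata; `whisper/test.py` in this commit reads these fields, for example:

```
import whisper

result = whisper.transcribe(speech_file)
print(result["language"])       # e.g. "en"
print(len(result["segments"]))  # number of decoded segments
```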
whisper/benchmark.py (new file, 91 lines)
@@ -0,0 +1,91 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from whisper import load_models
|
||||
from whisper import audio
|
||||
from whisper import decoding
|
||||
from whisper import transcribe
|
||||
|
||||
audio_file = "whisper/assets/ls_test.flac"
|
||||
|
||||
|
||||
def timer(fn, *args):
|
||||
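# Warm up with a few untimed calls, then report the average wall-clock time over num_its timed runs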
for _ in range(5):
|
||||
fn(*args)
|
||||
|
||||
num_its = 10
|
||||
|
||||
tic = time.perf_counter()
|
||||
for _ in range(num_its):
|
||||
fn(*args)
|
||||
toc = time.perf_counter()
|
||||
return (toc - tic) / num_its
|
||||
|
||||
|
||||
def feats():
|
||||
data = audio.load_audio(audio_file)
|
||||
data = audio.pad_or_trim(data)
|
||||
mels = audio.log_mel_spectrogram(data)
|
||||
mx.eval(mels)
|
||||
return mels
|
||||
|
||||
|
||||
def model_forward(model, mels, tokens):
|
||||
logits = model(mels, tokens)
|
||||
mx.eval(logits)
|
||||
return logits
|
||||
|
||||
|
||||
def decode(model, mels):
|
||||
return decoding.decode(model, mels)
|
||||
|
||||
|
||||
def everything():
|
||||
return transcribe(audio_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
feat_time = timer(feats)
|
||||
print(f"Feature time {feat_time:.3f}")
|
||||
mels = feats()[None]
|
||||
tokens = mx.array(
|
||||
[
|
||||
50364,
|
||||
1396,
|
||||
264,
|
||||
665,
|
||||
5133,
|
||||
23109,
|
||||
25462,
|
||||
264,
|
||||
6582,
|
||||
293,
|
||||
750,
|
||||
632,
|
||||
42841,
|
||||
292,
|
||||
370,
|
||||
938,
|
||||
294,
|
||||
4054,
|
||||
293,
|
||||
12653,
|
||||
356,
|
||||
50620,
|
||||
50620,
|
||||
23563,
|
||||
322,
|
||||
3312,
|
||||
13,
|
||||
50680,
|
||||
],
|
||||
mx.int32,
|
||||
)[None]
|
||||
model = load_models.load_model("tiny")
|
||||
model_forward_time = timer(model_forward, model, mels, tokens)
|
||||
print(f"Model forward time {model_forward_time:.3f}")
|
||||
decode_time = timer(decode, model, mels)
|
||||
print(f"Decode time {decode_time:.3f}")
|
||||
everything_time = timer(everything)
|
||||
print(f"Everything time {everything_time:.3f}")
|
whisper/requirements.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
numba
numpy
torch
tqdm
more-itertools
tiktoken==0.3.3
whisper/test.py (new file, 270 lines)
@@ -0,0 +1,270 @@
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import os
|
||||
import subprocess
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
import whisper.audio as audio
|
||||
import whisper.load_models as load_models
|
||||
import whisper.torch_whisper as torch_whisper
|
||||
import whisper.decoding as decoding
|
||||
|
||||
|
||||
TEST_AUDIO = "whisper/assets/ls_test.flac"
|
||||
|
||||
|
||||
def forward_torch(model, mels, tokens):
|
||||
mels = torch.Tensor(mels).to(torch.float32)
|
||||
tokens = torch.Tensor(tokens).to(torch.int32)
|
||||
with torch.no_grad():
|
||||
logits = model.forward(mels, tokens)
|
||||
return logits.numpy()
|
||||
|
||||
|
||||
def forward_mlx(model, mels, tokens):
|
||||
mels = mx.array(mels.transpose(0, 2, 1))
|
||||
tokens = mx.array(tokens, mx.int32)
|
||||
logits = model(mels, tokens)
|
||||
return np.array(logits)
|
||||
|
||||
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = load_models.load_model("tiny")
|
||||
data = audio.load_audio(TEST_AUDIO)
|
||||
data = audio.pad_or_trim(data)
|
||||
cls.mels = audio.log_mel_spectrogram(data)
|
||||
|
||||
def test_torch_mlx(self):
|
||||
np.random.seed(10)
|
||||
|
||||
torch_model = load_models.load_torch_model("tiny")
|
||||
dims = torch_model.dims
|
||||
|
||||
mels = np.random.randn(1, dims.n_mels, 3_000)
|
||||
tokens = np.random.randint(0, dims.n_vocab, (1, 20))
|
||||
|
||||
torch_logits = forward_torch(torch_model, mels, tokens)
|
||||
|
||||
mlx_model = load_models.torch_to_mlx(torch_model)
|
||||
mlx_logits = forward_mlx(mlx_model, mels, tokens)
|
||||
|
||||
self.assertTrue(np.allclose(torch_logits, mlx_logits, atol=1e-2, rtol=1e-2))
|
||||
|
||||
def test_decode_lang(self):
|
||||
options = decoding.DecodingOptions(task="lang_id")
|
||||
result = decoding.decode(self.model, self.mels, options)
|
||||
self.assertEqual(result.language, "en")
|
||||
self.assertEqual(len(result.language_probs), 99)
|
||||
self.assertAlmostEqual(
|
||||
result.language_probs["en"], 0.9947282671928406, places=5
|
||||
)
|
||||
|
||||
def test_decode_greedy(self):
|
||||
result = decoding.decode(self.model, self.mels)
|
||||
self.assertEqual(result.language, "en")
|
||||
self.assertEqual(
|
||||
result.tokens,
|
||||
[
|
||||
50364,
|
||||
1396,
|
||||
264,
|
||||
665,
|
||||
5133,
|
||||
23109,
|
||||
25462,
|
||||
264,
|
||||
6582,
|
||||
293,
|
||||
750,
|
||||
632,
|
||||
42841,
|
||||
292,
|
||||
370,
|
||||
938,
|
||||
294,
|
||||
4054,
|
||||
293,
|
||||
12653,
|
||||
356,
|
||||
50620,
|
||||
50620,
|
||||
23563,
|
||||
322,
|
||||
3312,
|
||||
13,
|
||||
50680,
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
result.text,
|
||||
(
|
||||
"Then the good soul openly sorted the boat and she "
|
||||
"had buoyed so long in secret and bravely stretched on alone."
|
||||
),
|
||||
)
|
||||
self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
|
||||
self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
# Small temp should give the same results
|
||||
result = decoding.decode(self.model, self.mels, temperature=1e-8)
|
||||
|
||||
self.assertEqual(
|
||||
result.text,
|
||||
(
|
||||
"Then the good soul openly sorted the boat and she "
|
||||
"had buoyed so long in secret and bravely stretched on alone."
|
||||
),
|
||||
)
|
||||
self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
|
||||
self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
def test_transcribe(self):
|
||||
result = whisper.transcribe(TEST_AUDIO)
|
||||
self.assertEqual(
|
||||
result["text"],
|
||||
(
|
||||
" Then the good soul openly sorted the boat and she "
|
||||
"had buoyed so long in secret and bravely stretched on alone."
|
||||
),
|
||||
)
|
||||
|
||||
def test_transcribe_alice(self):
|
||||
audio_file = os.path.join(
|
||||
os.path.expanduser("~"),
|
||||
".cache/whisper/alice.mp3",
|
||||
)
|
||||
if not os.path.exists(audio_file):
|
||||
print("To run this test download the alice in wonderland audiobook:")
|
||||
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
||||
return
|
||||
|
||||
result = whisper.transcribe(audio_file)
|
||||
self.assertEqual(len(result["text"]), 10920)
|
||||
self.assertEqual(result["language"], "en")
|
||||
self.assertEqual(len(result["segments"]), 77)
|
||||
|
||||
expected_5 = {
|
||||
"id": 5,
|
||||
"seek": 2800,
|
||||
"start": 40.0,
|
||||
"end": 46.0,
|
||||
"text": " Oh my poor little feet, I wonder who will put on your shoes and stockings for you now tears.",
|
||||
"tokens": [
|
||||
50964,
|
||||
876,
|
||||
452,
|
||||
4716,
|
||||
707,
|
||||
3521,
|
||||
11,
|
||||
286,
|
||||
2441,
|
||||
567,
|
||||
486,
|
||||
829,
|
||||
322,
|
||||
428,
|
||||
6654,
|
||||
293,
|
||||
4127,
|
||||
1109,
|
||||
337,
|
||||
291,
|
||||
586,
|
||||
10462,
|
||||
13,
|
||||
51264,
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"avg_logprob": -0.19670599699020386,
|
||||
"compression_ratio": 1.5991379310344827,
|
||||
"no_speech_prob": 0.09746722131967545,
|
||||
}
|
||||
|
||||
expected_73 = {
|
||||
"id": 73,
|
||||
"seek": 70700,
|
||||
"start": 707.0,
|
||||
"end": 715.0,
|
||||
"text": " let us get to the shore, and then I'll tell you my history, and you'll understand why it is that I hate cats and dogs.",
|
||||
"tokens": [
|
||||
50364,
|
||||
718,
|
||||
505,
|
||||
483,
|
||||
281,
|
||||
264,
|
||||
17805,
|
||||
11,
|
||||
293,
|
||||
550,
|
||||
286,
|
||||
603,
|
||||
980,
|
||||
291,
|
||||
452,
|
||||
2503,
|
||||
11,
|
||||
293,
|
||||
291,
|
||||
603,
|
||||
1223,
|
||||
983,
|
||||
309,
|
||||
307,
|
||||
300,
|
||||
286,
|
||||
4700,
|
||||
11111,
|
||||
293,
|
||||
7197,
|
||||
13,
|
||||
50764,
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"avg_logprob": -0.1350895343440594,
|
||||
"compression_ratio": 1.6208333333333333,
|
||||
"no_speech_prob": 0.002246702555567026,
|
||||
}
|
||||
|
||||
def check_segment(seg, expected):
|
||||
for k, v in expected.items():
|
||||
if isinstance(v, float):
|
||||
self.assertAlmostEqual(seg[k], v, places=3)
|
||||
else:
|
||||
self.assertEqual(seg[k], v)
|
||||
|
||||
# Randomly check a couple of segments
|
||||
check_segment(result["segments"][5], expected_5)
|
||||
check_segment(result["segments"][73], expected_73)
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
def test_load(self):
|
||||
data = audio.load_audio(TEST_AUDIO)
|
||||
data_8k = audio.load_audio(TEST_AUDIO, 8000)
|
||||
n = 106640
|
||||
self.assertTrue(data.shape, (n,))
|
||||
self.assertTrue(data.dtype, np.float32)
|
||||
self.assertTrue(data_8k.shape, (n // 2,))
|
||||
|
||||
def test_pad(self):
|
||||
data = audio.load_audio(TEST_AUDIO)
|
||||
data = audio.pad_or_trim(data, 20_000)
|
||||
self.assertTrue(data.shape, [20_000])
|
||||
|
||||
def test_mel_spec(self):
|
||||
mels = audio.log_mel_spectrogram(TEST_AUDIO)
|
||||
self.assertTrue(mels.shape, [80, 400])
|
||||
self.assertTrue(mels.dtype, mx.float32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
4
whisper/whisper/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from . import load_models
|
||||
from . import audio
|
||||
from . import decoding
|
||||
from .transcribe import transcribe
|
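The package exposes `transcribe` at the top level. A short usage sketch mirroring `test_transcribe` above; the audio path refers to the test asset bundled with this example and is an assumption about your working directory:

```python
import whisper

# Transcribe the bundled LibriSpeech test clip with the default model.
result = whisper.transcribe("whisper/assets/ls_test.flac")
print(result["text"])
print(result["language"], len(result["segments"]))
```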
10
whisper/whisper/assets/download_alice.sh
Normal file
@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
audio_file=$HOME/.cache/whisper/alice.mp3
|
||||
echo $audio_file
|
||||
zipf=alice_in_wonderland_librivox_64kb_mp3.zip
|
||||
url=https://www.archive.org/download/alice_in_wonderland_librivox/
|
||||
curl -LO $url/$zipf
|
||||
unzip $zipf
|
||||
mv wonderland_ch_02_64kb.mp3 $audio_file
|
||||
rm wonderland_* $zipf
|
50256
whisper/whisper/assets/gpt2.tiktoken
Normal file
File diff suppressed because it is too large
BIN
whisper/whisper/assets/ls_test.flac
Normal file
Binary file not shown.
BIN
whisper/whisper/assets/mel_filters.npz
Normal file
Binary file not shown.
50257
whisper/whisper/assets/multilingual.tiktoken
Normal file
File diff suppressed because it is too large
173
whisper/whisper/audio.py
Normal file
@ -0,0 +1,173 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from subprocess import CalledProcessError, run
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = N_SAMPLES // HOP_LENGTH # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
|
||||
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
Open an audio file and read it as a mono waveform, resampling as necessary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file: str
|
||||
The audio file to open
|
||||
|
||||
sr: int
|
||||
The sample rate to resample the audio if necessary
|
||||
|
||||
Returns
|
||||
-------
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# fmt: off
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr),
|
||||
"-"
|
||||
]
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
except CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||
"""
|
||||
if array.shape[axis] > length:
|
||||
sl = [slice(None)] * array.ndim
|
||||
sl[axis] = slice(0, length)
|
||||
array = array[tuple(sl)]
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
pad_fn = mx.pad if isinstance(array, mx.array) else np.pad
|
||||
array = pad_fn(array, pad_widths)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(n_mels: int = N_MELS) -> mx.array:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
This allows decoupling the librosa dependency; the filters were saved using:
|
||||
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
|
||||
filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
return mx.load(filename)[f"mel_{n_mels}"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def hanning(size):
|
||||
return mx.array(np.hanning(size + 1)[:-1])
|
||||
|
||||
|
||||
def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="reflect"):
|
||||
if nfft is None:
|
||||
nfft = nperseg
|
||||
if noverlap is None:
|
||||
noverlap = nfft // 4
|
||||
|
||||
def _pad(x, padding, pad_mode="constant"):
|
||||
if pad_mode == "constant":
|
||||
return mx.pad(x, [(padding, padding)])
|
||||
elif pad_mode == "reflect":
|
||||
prefix = x[1 : padding + 1][::-1]
|
||||
suffix = x[-(padding + 1) : -1][::-1]
|
||||
return mx.concatenate([prefix, x, suffix])
|
||||
else:
|
||||
raise ValueError(f"Invalid pad_mode {pad_mode}")
|
||||
|
||||
padding = nperseg // 2
|
||||
x = _pad(x, padding, pad_mode)
|
||||
|
||||
strides = [noverlap, 1]
|
||||
t = (x.size - nperseg + noverlap) // noverlap
|
||||
shape = [t, nfft]
|
||||
x = mx.as_strided(x, shape=shape, strides=strides)
|
||||
return mx.fft.rfft(x * window)
|
||||
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray],
|
||||
n_mels: int = N_MELS,
|
||||
padding: int = 0,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of the given audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, mx.array], shape = (*)
|
||||
The path to an audio file, or a NumPy or MLX array containing the audio waveform at 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
Returns
|
||||
-------
|
||||
mx.array, shape = (80, n_frames)
|
||||
An array that contains the Mel spectrogram
|
||||
"""
|
||||
device = mx.default_device()
|
||||
mx.set_default_device(mx.cpu)
|
||||
if not isinstance(audio, mx.array):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = mx.array(audio)
|
||||
|
||||
if padding > 0:
|
||||
audio = mx.pad(audio, (0, padding))
|
||||
window = hanning(N_FFT)
|
||||
freqs = stft(audio, window, nperseg=N_FFT, noverlap=HOP_LENGTH)
|
||||
magnitudes = freqs[:-1, :].abs().square()
|
||||
|
||||
filters = mel_filters(n_mels)
|
||||
mel_spec = magnitudes @ filters.T
|
||||
|
||||
log_spec = mx.maximum(mel_spec, 1e-10).log10()
|
||||
log_spec = mx.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
mx.set_default_device(device)
|
||||
return log_spec
|
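A minimal sketch of how the helpers in audio.py fit together, assuming the package is importable and the bundled test clip is present (both paths are assumptions about your checkout):

```python
import whisper.audio as audio

samples = audio.load_audio("whisper/assets/ls_test.flac")  # decoded by ffmpeg to 16 kHz mono float32
samples = audio.pad_or_trim(samples)                       # pad or trim to 30 s (480,000 samples)
mels = audio.log_mel_spectrogram(samples)                  # 80 mel bins, one frame per 10 ms hop
print(mels.shape)
```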
718
whisper/whisper/decoding.py
Normal file
@ -0,0 +1,718 @@
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import zlib
|
||||
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
text_bytes = text.encode("utf-8")
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
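For intuition: repetitive text compresses well under zlib, so degenerate, looping transcriptions produce a noticeably larger ratio than natural speech. A small self-contained check (exact values depend on the zlib build and are only indicative):

```python
import zlib

def ratio(text: str) -> float:
    b = text.encode("utf-8")
    return len(b) / len(zlib.compress(b))

print(ratio("Then the good soul openly sorted the boat and she had buoyed so long."))  # roughly 1
print(ratio("la la la " * 50))  # well above 2: a telltale sign of a decoding loop
```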
||||
|
||||
|
||||
def detect_language(
|
||||
model: "Whisper", mel: mx.array, tokenizer: Tokenizer = None
|
||||
) -> Tuple[mx.array, List[dict]]:
|
||||
"""
|
||||
Detect the spoken languages in the audio and return them as a list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||
|
||||
Returns
|
||||
-------
|
||||
language_tokens : mx.array, shape = (n_audio,)
|
||||
ids of the most probable language tokens, which appear after the startoftranscript token.
|
||||
language_probs : List[Dict[str, float]], length = n_audio
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
):
|
||||
raise ValueError(
|
||||
"This model doesn't have language tokens so it can't perform lang id"
|
||||
)
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel[None]
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != [model.dims.n_audio_ctx, model.dims.n_audio_state]:
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = mel.shape[0]
|
||||
x = mx.array([[tokenizer.sot]] * n_audio) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
|
||||
mask[list(tokenizer.all_language_tokens)] = 0.0
|
||||
logits += mx.array(mask)
|
||||
language_tokens = mx.argmax(logits, axis=-1)
|
||||
language_token_probs = mx.softmax(logits, axis=-1)
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
return language_tokens, language_probs
|
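Language identification is exposed through the same decoding entry point; this mirrors `test_decode_lang` in whisper/test.py (the model download and the bundled clip are assumed to be available):

```python
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models

model = load_models.load_model("tiny")
mels = audio.log_mel_spectrogram(audio.pad_or_trim(audio.load_audio("whisper/assets/ls_test.flac")))

result = decoding.decode(model, mels, task="lang_id")
print(result.language)             # "en" for the bundled clip
print(len(result.language_probs))  # 99 languages in the multilingual vocabulary
```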
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
# whether to perform X->X "transcribe" or X->English "translate"
|
||||
task: str = "transcribe"
|
||||
|
||||
# language that the audio is in; uses detected language if None
|
||||
language: Optional[str] = None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
|
||||
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
|
||||
|
||||
# "alpha" in Google NMT, or None for length norm, when ranking generations
|
||||
# to select which to return among the beams or best-of-N samples
|
||||
length_penalty: Optional[float] = None
|
||||
|
||||
# text or tokens to feed as the prompt or the prefix; for more info:
|
||||
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt: Optional[Union[str, List[int]]] = None # for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0
|
||||
|
||||
# implementation details
|
||||
fp16: bool = False # use fp16 for most of the calculation
|
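These options are consumed by `decode()` at the bottom of this file; all fields have defaults, so only the ones you want to override need to be passed. A hedged construction example (values are illustrative; note that `beam_size` would currently raise `NotImplementedError`):

```python
options = DecodingOptions(
    task="transcribe",        # or "translate" for X -> English
    language="en",            # skip automatic language detection
    temperature=0.0,          # greedy decoding
    without_timestamps=True,  # sample text tokens only
)
```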
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
audio_features: mx.array
|
||||
language: str
|
||||
language_probs: Optional[Dict[str, float]] = None
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
text: str = ""
|
||||
avg_logprob: float = np.nan
|
||||
no_speech_prob: float = np.nan
|
||||
temperature: float = np.nan
|
||||
compression_ratio: float = np.nan
|
||||
|
||||
|
||||
class Inference:
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = None
|
||||
|
||||
def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
logits, self.kv_cache = self.model.decoder(
|
||||
tokens, audio_features, kv_cache=self.kv_cache
|
||||
)
|
||||
return logits
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
"""Update the key-value cache according to the updated beams"""
|
||||
# update the key/value cache to contain the selected sequences
|
||||
if source_indices != list(range(len(source_indices))):
|
||||
self.kv_cache = tree_map(lambda x: x[source_indices], self.kv_cache)
|
||||
|
||||
def reset(self):
|
||||
self.kv_cache = None
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(
|
||||
self, tokens: List[List[mx.array]], sum_logprobs: List[List[float]]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaximumLikelihoodRanker(SequenceRanker):
|
||||
"""
|
||||
Select the sample with the highest log probabilities, penalized using either
|
||||
a simple length normalization or Google NMT paper's length penalty
|
||||
"""
|
||||
|
||||
def __init__(self, length_penalty: Optional[float]):
|
||||
self.length_penalty = length_penalty
|
||||
|
||||
def rank(self, tokens: List[List[List[int]]], sum_logprobs: List[List[float]]):
|
||||
def scores(logprobs, lengths):
|
||||
result = []
|
||||
for logprob, length in zip(logprobs, lengths):
|
||||
if self.length_penalty is None:
|
||||
penalty = length
|
||||
else:
|
||||
# from the Google NMT paper
|
||||
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||
result.append(logprob / penalty)
|
||||
return result
|
||||
|
||||
# get the sequence with the highest score
|
||||
lengths = [[len(t) for t in s] for s in tokens]
|
||||
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
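For intuition about the ranking above: with `length_penalty=None` the score is simply `logprob / length`, otherwise the Google NMT penalty `((5 + length) / 6) ** alpha` is used. A quick numeric check with an illustrative `alpha`:

```python
def nmt_penalty(length: int, alpha: float) -> float:
    return ((5 + length) / 6) ** alpha

# With alpha = 0.6, the penalty grows sub-linearly with length:
print(nmt_penalty(10, 0.6))  # ~1.73
print(nmt_penalty(20, 0.6))  # ~2.35
```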
||||
|
||||
|
||||
class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(
|
||||
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||
) -> Tuple[mx.array, bool, mx.array]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : mx.array, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
logits : mx.array, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
sum_logprobs : mx.array, shape = (n_batch)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : mx.array, shape = (n_batch, current_sequence_length + 1)
|
||||
the tokens, appended with the selected next token
|
||||
|
||||
completed : bool
|
||||
True if all sequences have reached the end of text
|
||||
|
||||
sum_logprobs: mx.array, shape = (n_batch)
|
||||
updated cumulative log probabilities for each sequence
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(
|
||||
self, tokens: mx.array, sum_logprobs: mx.array
|
||||
) -> Tuple[Sequence[Sequence[mx.array]], List[List[float]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : mx.array, shape = (n_audio, n_group, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence
|
||||
|
||||
sum_logprobs : mx.array, shape = (n_audio, n_group)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Sequence[Sequence[mx.array]], length = n_audio
|
||||
sequence of mx.arrays containing candidate token sequences, for each audio input
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(
|
||||
self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
|
||||
) -> Tuple[mx.array, bool, mx.array]:
|
||||
if self.temperature == 0:
|
||||
next_tokens = logits.argmax(axis=-1)
|
||||
else:
|
||||
next_tokens = mx.random.categorical(logits=logits / self.temperature)
|
||||
|
||||
|
||||
logits = logits.astype(mx.float32)
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
|
||||
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
|
||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||
|
||||
eot_mask = tokens[:, -1] == self.eot
|
||||
next_tokens = next_tokens * (1 - eot_mask) + self.eot * eot_mask
|
||||
tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=-1)
|
||||
|
||||
completed = mx.all(tokens[:, -1] == self.eot)
|
||||
return tokens, completed, sum_logprobs
|
||||
|
||||
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
|
||||
"""Apply any filtering or masking to logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : mx.array, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
tokens : mx.array, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SuppressBlank(LogitFilter):
|
||||
def __init__(self, tokenizer: Tokenizer, sample_begin: int, n_vocab: int):
|
||||
self.sample_begin = sample_begin
|
||||
mask = np.zeros(n_vocab, np.float32)
|
||||
mask[tokenizer.encode(" ") + [tokenizer.eot]] = -np.inf
|
||||
self.mask = mx.array(mask)
|
||||
|
||||
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
return logits + self.mask
|
||||
return logits
|
||||
|
||||
|
||||
class SuppressTokens(LogitFilter):
|
||||
def __init__(self, suppress_tokens: Sequence[int], n_vocab: int):
|
||||
mask = np.zeros(n_vocab, np.float32)
|
||||
mask[list(suppress_tokens)] = -np.inf
|
||||
self.mask = mx.array(mask)
|
||||
|
||||
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
|
||||
return logits + self.mask
|
||||
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
sample_begin: int,
|
||||
max_initial_timestamp_index: Optional[int],
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||
|
||||
def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
|
||||
mask = np.zeros(logits.shape, np.float32)
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
mask[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
sampled_tokens = tokens[k, self.sample_begin :]
|
||||
seq = sampled_tokens.tolist()
|
||||
last_was_timestamp = (
|
||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
penultimate_was_timestamp = (
|
||||
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
mask[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||
else: # cannot be normal text tokens
|
||||
mask[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
timestamps = [
|
||||
i for i, v in enumerate(seq) if v > self.tokenizer.timestamp_begin
|
||||
]
|
||||
if len(timestamps) > 0:
|
||||
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
# also force each segment to have a nonzero length, to prevent infinite looping
|
||||
last_timestamp = timestamps[-1]
|
||||
if not last_timestamp or penultimate_was_timestamp:
|
||||
last_timestamp += 1
|
||||
mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
mask[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = (
|
||||
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
)
|
||||
mask[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
axis=-1
|
||||
)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
mask[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
return logits + mx.array(mask, logits.dtype)
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
inference: Inference
|
||||
sequence_ranker: SequenceRanker
|
||||
decoder: TokenDecoder
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, language=language, task=options.task
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||
|
||||
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||
if self.options.without_timestamps:
|
||||
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||
|
||||
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||
self.sample_begin: int = len(self.initial_tokens)
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = Inference(model, len(self.initial_tokens))
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
raise NotImplementedError("Beam search decoder is not yet implemented")
|
||||
# self.decoder = BeamSearchDecoder(
|
||||
# options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
# )
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||
self.logit_filters = []
|
||||
if self.options.suppress_blank:
|
||||
self.logit_filters.append(
|
||||
SuppressBlank(self.tokenizer, self.sample_begin, model.dims.n_vocab)
|
||||
)
|
||||
if self.options.suppress_tokens:
|
||||
self.logit_filters.append(
|
||||
SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
|
||||
)
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(
|
||||
self.options.max_initial_timestamp / precision
|
||||
)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(
|
||||
tokenizer, self.sample_begin, max_initial_timestamp_index
|
||||
)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
if options.beam_size is not None and options.best_of is not None:
|
||||
raise ValueError("beam_size and best_of can't be given together")
|
||||
if options.temperature == 0:
|
||||
if options.best_of is not None:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (
|
||||
0 <= options.length_penalty <= 1
|
||||
):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
|
||||
if prefix := self.options.prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip())
|
||||
if isinstance(prefix, str)
|
||||
else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt := self.options.prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip())
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
suppress_tokens = self.options.suppress_tokens
|
||||
|
||||
if isinstance(suppress_tokens, str):
|
||||
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
def _get_audio_features(self, mel: mx.array):
|
||||
if self.options.fp16:
|
||||
mel = mel.astype(mx.float16)
|
||||
|
||||
if mel.shape[-2:] == (
|
||||
self.model.dims.n_audio_ctx,
|
||||
self.model.dims.n_audio_state,
|
||||
):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
audio_features = self.model.encoder(mel)
|
||||
|
||||
if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32):
|
||||
raise TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
)
|
||||
|
||||
return audio_features
|
||||
|
||||
def _detect_language(self, audio_features: mx.array, tokens: np.array):
|
||||
languages = [self.options.language] * audio_features.shape[0]
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(
|
||||
audio_features, self.tokenizer
|
||||
)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
# write language tokens
|
||||
tokens[:, self.sot_index + 1] = np.array(lang_tokens)
|
||||
|
||||
return languages, lang_probs
|
||||
|
||||
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: mx.array = mx.zeros(n_batch)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
|
||||
try:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
probs_at_sot = mx.softmax(
|
||||
logits[:, self.sot_index].astype(mx.float32), axis=-1
|
||||
)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
|
||||
# apply the logit filters, e.g. for suppressing or penalizing certain tokens
|
||||
for logit_filter in self.logit_filters:
|
||||
logits = logit_filter.apply(logits, tokens)
|
||||
|
||||
# expand the tokens tensor with the selected next tokens
|
||||
tokens, completed, sum_logprobs = self.decoder.update(
|
||||
tokens, logits, sum_logprobs
|
||||
)
|
||||
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
finally:
|
||||
self.inference.reset()
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs
|
||||
|
||||
def run(self, mel: mx.array) -> List[DecodingResult]:
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
||||
audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
|
||||
tokens: np.array = np.array(self.initial_tokens)
|
||||
tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features, language=language, language_probs=probs
|
||||
)
|
||||
for features, language, probs in zip(
|
||||
audio_features, languages, language_probs
|
||||
)
|
||||
]
|
||||
|
||||
# repeat tokens by the group size, for beam search or best-of-n sampling
|
||||
tokens = mx.array(tokens)
|
||||
if self.n_group > 1:
|
||||
tokens = tokens[:, None, :]
|
||||
tokens = mx.broadcast_to(
|
||||
tokens, [n_audio, self.n_group, len(self.initial_tokens)]
|
||||
)
|
||||
tokens = tokens.reshape(
|
||||
(n_audio * self.n_group, len(self.initial_tokens))
|
||||
)
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens = tokens[..., self.sample_begin :].tolist()
|
||||
tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i] for i, t in zip(selected, tokens)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [
|
||||
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
|
||||
]
|
||||
|
||||
fields = (
|
||||
texts,
|
||||
languages,
|
||||
tokens,
|
||||
audio_features,
|
||||
avg_logprobs,
|
||||
no_speech_probs,
|
||||
)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features,
|
||||
language=language,
|
||||
tokens=tokens,
|
||||
text=text,
|
||||
avg_logprob=avg_logprob,
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
|
||||
*fields
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def decode(
|
||||
model: "Whisper",
|
||||
mel: mx.array,
|
||||
options: DecodingOptions = DecodingOptions(),
|
||||
**kwargs,
|
||||
) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
the Whisper model instance
|
||||
|
||||
mel: mx.array, shape = (80, 3000) or (*, 80, 3000)
|
||||
An array containing the Mel spectrogram(s)
|
||||
|
||||
options: DecodingOptions
|
||||
A dataclass that contains all necessary options for decoding 30-second segments
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
if single := mel.ndim == 2:
|
||||
mel = mel[None]
|
||||
|
||||
if kwargs:
|
||||
options = replace(options, **kwargs)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
return result[0] if single else result
|
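An end-to-end sketch of `decode()`, mirroring `test_decode_greedy` in whisper/test.py (the "tiny" checkpoint is downloaded on first use, and the clip path is an assumption about your checkout):

```python
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models

model = load_models.load_model("tiny")
samples = audio.pad_or_trim(audio.load_audio("whisper/assets/ls_test.flac"))
mels = audio.log_mel_spectrogram(samples)

result = decoding.decode(model, mels)  # greedy decoding with the default options
print(result.text)
print(result.avg_logprob, result.no_speech_prob, result.compression_ratio)
```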
192
whisper/whisper/load_models.py
Normal file
@ -0,0 +1,192 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import whisper
|
||||
from . import torch_whisper
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str) -> str:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_torch_model(
|
||||
name: str,
|
||||
download_root: str = None,
|
||||
) -> torch_whisper.Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if download_root is None:
|
||||
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
with open(checkpoint_file, "rb") as fp:
|
||||
checkpoint = torch.load(fp)
|
||||
|
||||
dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
|
||||
model = torch_whisper.Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert(model, rules=None):
|
||||
params = {}
|
||||
if rules is not None and type(model) in rules:
|
||||
out = rules[type(model)](model, rules)
|
||||
return out
|
||||
if isinstance(model, torch.Tensor):
|
||||
return mx.array(model.detach().numpy())
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
return [convert(n, rules) for n in model.children()]
|
||||
if isinstance(model, torch.nn.Conv1d):
|
||||
return {
|
||||
"weight": convert(model.weight).transpose(0, 2, 1),
|
||||
"bias": convert(model.bias),
|
||||
}
|
||||
for k, n in model.named_children():
|
||||
if k in rules:
|
||||
params.update(rules[k](n, rules))
|
||||
else:
|
||||
params[k] = convert(n, rules)
|
||||
for k, p in model.named_parameters(recurse=False):
|
||||
params[k] = convert(p)
|
||||
return params
|
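`convert` maps torch tensors to `mx.array` and recurses through child modules (Conv1d weights get their channels transposed for MLX). A tiny illustrative check on a single layer without children; the layer itself is hypothetical, not part of this example:

```python
import torch
from whisper.load_models import convert

layer = torch.nn.Linear(4, 2)
params = convert(layer)  # leaf tensors become mx.array: {"weight": ..., "bias": ...}
print(sorted(params.keys()), params["weight"].shape)
```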
||||
|
||||
|
||||
def torch_to_mlx(
|
||||
torch_model: torch_whisper.Whisper,
|
||||
) -> whisper.Whisper:
|
||||
def convert_rblock(model, rules):
|
||||
children = dict(model.named_children())
|
||||
mlp = list(children.pop("mlp").children())
|
||||
params = {
|
||||
"mlp1": convert(mlp[0], rules),
|
||||
"mlp2": convert(mlp[-1], rules),
|
||||
}
|
||||
for k, n in children.items():
|
||||
params[k] = convert(n, rules)
|
||||
return params
|
||||
|
||||
rules = {
|
||||
torch_whisper.ResidualAttentionBlock: convert_rblock,
|
||||
}
|
||||
|
||||
params = convert(torch_model, rules)
|
||||
|
||||
mlx_model = whisper.Whisper(torch_model.dims)
|
||||
mlx_model.update(params)
|
||||
return mlx_model
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
download_root: str = None,
|
||||
) -> whisper.Whisper:
|
||||
return torch_to_mlx(load_torch_model(name, download_root))
|
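Typical entry points for this module; the checkpoint is downloaded to `~/.cache/whisper` on first use (network access assumed):

```python
from whisper import load_models

print(load_models.available_models())   # "tiny.en", "tiny", ..., "large"
model = load_models.load_model("tiny")  # download the PyTorch weights and convert them to MLX
print(model.dims)
```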
385
whisper/whisper/timing.py
Normal file
@ -0,0 +1,385 @@
|
||||
import itertools
|
||||
import subprocess
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def median_filter(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
|
||||
pad_width = filter_width // 2
|
||||
if x.shape[-1] <= pad_width:
|
||||
# F.pad requires the padding width to be smaller than the input dimension
|
||||
return x
|
||||
|
||||
if (ndim := x.ndim) <= 2:
|
||||
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
|
||||
x = x[None, None, :]
|
||||
|
||||
assert (
|
||||
filter_width > 0 and filter_width % 2 == 1
|
||||
), "`filter_width` should be an odd number"
|
||||
|
||||
result = None
|
||||
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
||||
if x.is_cuda:
|
||||
try:
|
||||
from .triton_ops import median_filter_cuda
|
||||
|
||||
result = median_filter_cuda(x, filter_width)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower median kernel implementation..."
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
|
||||
if ndim <= 2:
|
||||
result = result[0, 0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@numba.jit(nopython=True)
|
||||
def backtrace(trace: np.ndarray):
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
result = []
|
||||
while i > 0 or j > 0:
|
||||
result.append((i - 1, j - 1))
|
||||
|
||||
if trace[i, j] == 0:
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif trace[i, j] == 1:
|
||||
i -= 1
|
||||
elif trace[i, j] == 2:
|
||||
j -= 1
|
||||
else:
|
||||
raise ValueError("Unexpected trace[i, j]")
|
||||
|
||||
result = np.array(result)
|
||||
return result[::-1, :].T
|
||||
|
||||
|
||||
@numba.jit(nopython=True, parallel=True)
|
||||
def dtw_cpu(x: np.ndarray):
|
||||
N, M = x.shape
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, M + 1):
|
||||
for i in range(1, N + 1):
|
||||
c0 = cost[i - 1, j - 1]
|
||||
c1 = cost[i - 1, j]
|
||||
c2 = cost[i, j - 1]
|
||||
|
||||
if c0 < c1 and c0 < c2:
|
||||
c, t = c0, 0
|
||||
elif c1 < c0 and c1 < c2:
|
||||
c, t = c1, 1
|
||||
else:
|
||||
c, t = c2, 2
|
||||
|
||||
cost[i, j] = x[i - 1, j - 1] + c
|
||||
trace[i, j] = t
|
||||
|
||||
return backtrace(trace)
|
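A tiny check of the DTW helper on a hypothetical 3x4 cost matrix; the returned pair of index arrays is the monotonic alignment path that `find_alignment` later uses to map tokens to audio frames:

```python
import numpy as np

cost = np.array(
    [
        [0.0, 1.0, 2.0, 3.0],
        [1.0, 0.0, 1.0, 2.0],
        [2.0, 1.0, 0.0, 0.5],
    ],
    dtype=np.float32,
)
text_idx, time_idx = dtw_cpu(cost)
print(list(zip(text_idx, time_idx)))  # [(0, 0), (1, 1), (2, 2), (2, 3)] for this matrix
```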
||||
|
||||
|
||||
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
from .triton_ops import dtw_kernel
|
||||
|
||||
M, N = x.shape
|
||||
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||
|
||||
x_skew = (
|
||||
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||
)
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
cost,
|
||||
trace,
|
||||
x_skew,
|
||||
x_skew.stride(0),
|
||||
cost.stride(0),
|
||||
trace.stride(0),
|
||||
N,
|
||||
M,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||
:, : N + 1
|
||||
]
|
||||
return backtrace(trace.cpu().numpy())
|
||||
|
||||
|
||||
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||
if x.is_cuda:
|
||||
try:
|
||||
return dtw_cuda(x)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower DTW implementation..."
|
||||
)
|
||||
|
||||
return dtw_cpu(x.double().cpu().numpy())
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordTiming:
|
||||
word: str
|
||||
tokens: List[int]
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
def find_alignment(
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
text_tokens: List[int],
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
*,
|
||||
medfilt_width: int = 7,
|
||||
qk_scale: float = 1.0,
|
||||
) -> List[WordTiming]:
|
||||
if len(text_tokens) == 0:
|
||||
return []
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*text_tokens,
|
||||
tokenizer.eot,
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# install hooks on the cross attention layers to retrieve the attention weights
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = (weights - mean) / std
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
text_indices, time_indices = dtw(-matrix)
|
||||
|
||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||
if len(word_tokens) <= 1:
|
||||
# return on eot only
|
||||
# >>> np.pad([], (1, 0))
|
||||
# array([0.])
|
||||
# This results in crashes when we lookup jump_times with float, like
|
||||
# IndexError: arrays used as indices must be of integer (or boolean) type
|
||||
return []
|
||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||
start_times = jump_times[word_boundaries[:-1]]
|
||||
end_times = jump_times[word_boundaries[1:]]
|
||||
word_probabilities = [
|
||||
np.mean(text_token_probs[i:j])
|
||||
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||
]
|
||||
|
||||
return [
|
||||
WordTiming(word, tokens, start, end, probability)
|
||||
for word, tokens, start, end, probability in zip(
|
||||
words, word_tokens, start_times, end_times, word_probabilities
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
|
||||
# merge prepended punctuations
|
||||
i = len(alignment) - 2
|
||||
j = len(alignment) - 1
|
||||
while i >= 0:
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
||||
# prepend it to the following word
|
||||
following.word = previous.word + following.word
|
||||
following.tokens = previous.tokens + following.tokens
|
||||
previous.word = ""
|
||||
previous.tokens = []
|
||||
else:
|
||||
j = i
|
||||
i -= 1
|
||||
|
||||
# merge appended punctuations
|
||||
i = 0
|
||||
j = 1
|
||||
while j < len(alignment):
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if not previous.word.endswith(" ") and following.word in appended:
|
||||
# append it to the previous word
|
||||
previous.word = previous.word + following.word
|
||||
previous.tokens = previous.tokens + following.tokens
|
||||
following.word = ""
|
||||
following.tokens = []
|
||||
else:
|
||||
i = j
|
||||
j += 1
|
||||
|
||||
|
||||
def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    last_speech_timestamp: float,
    **kwargs,
):
    if len(segments) == 0:
        return

    text_tokens_per_segment = [
        [token for token in segment["tokens"] if token < tokenizer.eot]
        for segment in segments
    ]

    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]
    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
    max_duration = median_duration * 2

    # hack: truncate long words at sentence boundaries.
    # a better segmentation algorithm based on VAD should be able to replace this.
    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        # ensure words at sentence boundaries are not longer than twice the median word duration.
        for i in range(1, len(alignment)):
            if alignment[i].end - alignment[i].start > max_duration:
                if alignment[i].word in sentence_end_marks:
                    alignment[i].end = alignment[i].start + max_duration
                elif alignment[i - 1].word in sentence_end_marks:
                    alignment[i].start = alignment[i].end - max_duration

    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_index = 0

    for segment, text_tokens in zip(segments, text_tokens_per_segment):
        saved_tokens = 0
        words = []

        while word_index < len(alignment) and saved_tokens < len(text_tokens):
            timing = alignment[word_index]

            if timing.word:
                words.append(
                    dict(
                        word=timing.word,
                        start=round(time_offset + timing.start, 2),
                        end=round(time_offset + timing.end, 2),
                        probability=timing.probability,
                    )
                )

            saved_tokens += len(timing.tokens)
            word_index += 1

        # hack: truncate long words at segment boundaries.
        # a better segmentation algorithm based on VAD should be able to replace this.
        if len(words) > 0:
            # ensure the first and second word after a pause is not longer than
            # twice the median word duration.
            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                words[0]["end"] - words[0]["start"] > max_duration
                or (
                    len(words) > 1
                    and words[1]["end"] - words[0]["start"] > max_duration * 2
                )
            ):
                if (
                    len(words) > 1
                    and words[1]["end"] - words[1]["start"] > max_duration
                ):
                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                    words[0]["end"] = words[1]["start"] = boundary
                words[0]["start"] = max(0, words[0]["end"] - max_duration)

            # prefer the segment-level start timestamp if the first word is too long.
            if (
                segment["start"] < words[0]["end"]
                and segment["start"] - 0.5 > words[0]["start"]
            ):
                words[0]["start"] = max(
                    0, min(words[0]["end"] - median_duration, segment["start"])
                )
            else:
                segment["start"] = words[0]["start"]

            # prefer the segment-level end timestamp if the last word is too long.
            if (
                segment["end"] > words[-1]["start"]
                and segment["end"] + 0.5 < words[-1]["end"]
            ):
                words[-1]["end"] = max(
                    words[-1]["start"] + median_duration, segment["end"]
                )
            else:
                segment["end"] = words[-1]["end"]

            last_speech_timestamp = segment["end"]

        segment["words"] = words
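
# Rough usage sketch (assumed, not part of this file): once transcription has
# produced `segments`, a call along the lines of
#   add_word_timestamps(segments=segments, model=model, tokenizer=tokenizer,
#                       mel=mel, num_frames=num_frames,
#                       last_speech_timestamp=0.0)
# fills each segment's "words" list in place with word/start/end/probability
# entries.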
387
whisper/whisper/tokenizer.py
Normal file
@ -0,0 +1,387 @@
|
||||
import base64
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tiktoken
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||
|
||||
encoding: tiktoken.Encoding
|
||||
language: Optional[str] = None
|
||||
task: Optional[str] = None
|
||||
sot_sequence: Tuple[int] = ()
|
||||
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for special in self.encoding.special_tokens_set:
|
||||
special_token = self.encoding.encode_single_token(special)
|
||||
self.special_tokens[special] = special_token
|
||||
|
||||
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||
translate: int = self.special_tokens["<|translate|>"]
|
||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if self.language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||
if self.task is not None:
|
||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||
sot_sequence.append(task_token)
|
||||
|
||||
self.sot_sequence = tuple(sot_sequence)
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.encoding.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.encoding.eot_token
|
||||
|
||||
@cached_property
|
||||
def transcribe(self) -> int:
|
||||
return self.special_tokens["<|transcribe|>"]
|
||||
|
||||
@cached_property
|
||||
def translate(self) -> int:
|
||||
return self.special_tokens["<|translate|>"]
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self.special_tokens["<|startoftranscript|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_lm(self) -> int:
|
||||
return self.special_tokens["<|startoflm|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self.special_tokens["<|startofprev|>"]
|
||||
|
||||
@cached_property
|
||||
def no_speech(self) -> int:
|
||||
return self.special_tokens["<|nospeech|>"]
|
||||
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self.special_tokens["<|notimestamps|>"]
|
||||
|
||||
@cached_property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.special_tokens["<|0.00|>"]
|
||||
|
||||
@cached_property
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
if token := self.special_tokens.get(f"<|{self.language}|>", None):
|
||||
return token
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in self.special_tokens.items():
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
|
||||
@cached_property
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
||||
|
||||
@cached_property
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@cached_property
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [
|
||||
self.encoding.encode(symbol),
|
||||
self.encoding.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
# without morpheme analysis. Here, we instead split words at any
|
||||
# position where the tokens are decoded as valid unicode points
|
||||
return self.split_tokens_on_unicode(tokens)
|
||||
|
||||
return self.split_tokens_on_spaces(tokens)
|
||||
|
||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||
decoded_full = self.decode_with_timestamps(tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
words = []
|
||||
word_tokens = []
|
||||
current_tokens = []
|
||||
unicode_offset = 0
|
||||
|
||||
for token in tokens:
|
||||
current_tokens.append(token)
|
||||
decoded = self.decode_with_timestamps(current_tokens)
|
||||
|
||||
if (
|
||||
replacement_char not in decoded
|
||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||
== replacement_char
|
||||
):
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
current_tokens = []
|
||||
unicode_offset += len(decoded)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
def split_tokens_on_spaces(self, tokens: List[int]):
|
||||
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
||||
words = []
|
||||
word_tokens = []
|
||||
|
||||
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
||||
special = subword_tokens[0] >= self.eot
|
||||
with_space = subword.startswith(" ")
|
||||
punctuation = subword.strip() in string.punctuation
|
||||
if special or with_space or punctuation or len(words) == 0:
|
||||
words.append(subword)
|
||||
word_tokens.append(subword_tokens)
|
||||
else:
|
||||
words[-1] = words[-1] + subword
|
||||
word_tokens[-1].extend(subword_tokens)
|
||||
|
||||
return words, word_tokens
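# Illustrative note (added; not in the original source): for an English text,
# split_to_word_tokens regroups sub-word tokens into whitespace- and
# punctuation-delimited words, e.g. roughly [' Hello', ',', ' world', '!'],
# each paired with the token ids that back it; for zh/ja/th/lo/my it falls
# back to splitting on valid unicode code points instead.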
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_encoding(name: str = "gpt2"):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
with open(vocab_path) as fid:
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in fid if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
special_tokens = {}
|
||||
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
|
||||
for token in specials:
|
||||
special_tokens[token] = n_vocab
|
||||
n_vocab += 1
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
encoding_name = "multilingual"
|
||||
language = language or "en"
|
||||
task = task or "transcribe"
|
||||
else:
|
||||
encoding_name = "gpt2"
|
||||
language = None
|
||||
task = None
|
||||
|
||||
encoding = get_encoding(name=encoding_name)
|
||||
|
||||
return Tokenizer(encoding=encoding, language=language, task=task)
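# Minimal usage sketch (assumed):
#   tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
#   tokens = tokenizer.encode(" hello world")
#   text = tokenizer.decode(tokens)
# The multilingual flag selects the "multilingual" tiktoken vocabulary, while
# English-only models use the "gpt2" vocabulary.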
|
301
whisper/whisper/torch_whisper.py
Normal file
@ -0,0 +1,301 @@
|
||||
import base64
|
||||
import gzip
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x,
|
||||
self.weight.to(x.dtype),
|
||||
None if self.bias is None else self.bias.to(x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache[self.key]
|
||||
v = kv_cache[self.value]
|
||||
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(
|
||||
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||
)
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
# use the last half layers for alignment by default; see `set_alignment_heads()` below
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||
|
||||
def set_alignment_heads(self, dump: bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = torch.from_numpy(array).reshape(
|
||||
self.dims.n_text_layer, self.dims.n_text_head
|
||||
)
|
||||
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab == 51865
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||
# save as-is, for the first token or cross attention
|
||||
cache[module] = output
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
return cache, hooks
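# Rough usage sketch (assumed, mirroring the docstring above):
#   kv_cache, hooks = model.install_kv_cache_hooks()
#   logits = model.decoder(tokens, audio_features, kv_cache=kv_cache)
#   ...
#   for hook in hooks:
#       hook.remove()
# so repeated decoder calls reuse the previously computed key/value tensors.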
|
358
whisper/whisper/transcribe.py
Normal file
@ -0,0 +1,358 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import sys
|
||||
from typing import Optional, Tuple, Union
|
||||
import tqdm
|
||||
|
||||
|
||||
from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .load_models import load_model
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
|
||||
|
||||
def _format_timestamp(seconds: float):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
hours = milliseconds // 3_600_000
|
||||
milliseconds -= hours * 3_600_000
|
||||
|
||||
minutes = milliseconds // 60_000
|
||||
milliseconds -= minutes * 60_000
|
||||
|
||||
seconds = milliseconds // 1_000
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if hours > 0 else ""
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
||||
|
||||
|
||||
class ModelHolder:
|
||||
model = None
|
||||
model_name = None
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str):
|
||||
if cls.model is None or model != cls.model_name:
|
||||
cls.model = load_model(model)
|
||||
cls.model_name = model
|
||||
return cls.model
|
||||
|
||||
|
||||
def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array],
|
||||
*,
|
||||
model: str = "tiny",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, mx.array]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
model: str
|
||||
The Whisper model. Can be any of ["tiny", "base", "small", "medium", "large"].
|
||||
Default is "tiny".
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
|
||||
model = ModelHolder.get_model(model)
|
||||
|
||||
dtype = mx.float16 if decode_options.get("fp16", False) else mx.float32
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-2] - N_FRAMES
|
||||
|
||||
if verbose:
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
if system_encoding != "utf-8":
|
||||
make_safe = lambda x: x.encode(system_encoding, errors="replace").decode(
|
||||
system_encoding
|
||||
)
|
||||
else:
|
||||
make_safe = lambda x: x
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. "
|
||||
"Use the `language` decoding option to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES, axis=-2).astype(dtype)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
def decode_with_fallback(segment: mx.array) -> DecodingResult:
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
kwargs = {**decode_options}
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
needs_fallback = False
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
if (
|
||||
no_speech_threshold is not None
|
||||
and decode_result.no_speech_prob > no_speech_threshold
|
||||
):
|
||||
needs_fallback = False # silence
|
||||
if not needs_fallback:
|
||||
break
|
||||
|
||||
return decode_result
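# Note (added for clarity): temperatures are tried in increasing order; a
# result that is too repetitive (high compression ratio) or too unlikely
# (low average logprob) triggers a retry at the next temperature, and the
# first acceptable result is returned.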
|
||||
|
||||
seek = 0
|
||||
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
) # time per output token: 0.02 (seconds)
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: mx.array, result: DecodingResult
|
||||
):
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
while seek < content_frames:
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
mel_segment = mel[seek : seek + N_FRAMES]
|
||||
segment_size = min(N_FRAMES, content_frames - seek)
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = np.array(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment_size # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
timestamp_tokens = tokens >= tokenizer.timestamp_begin
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
consecutive = np.where(
|
||||
np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:])
|
||||
)[0]
|
||||
consecutive += 1
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens))
|
||||
|
||||
last_slice = 0
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
seek += segment_size
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}"
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
for i, segment in enumerate(current_segments):
|
||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||
segment["text"] = ""
|
||||
segment["tokens"] = []
|
||||
segment["words"] = []
|
||||
|
||||
all_segments.extend(
|
||||
[
|
||||
{"id": i, **segment}
|
||||
for i, segment in enumerate(
|
||||
current_segments, start=len(all_segments)
|
||||
)
|
||||
]
|
||||
)
|
||||
all_tokens.extend(
|
||||
[token for segment in current_segments for token in segment["tokens"]]
|
||||
)
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
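# Minimal usage sketch (assumed; the exact import path depends on how the
# package exposes this module):
#   from whisper.transcribe import transcribe
#   result = transcribe("speech.wav", model="tiny")
#   print(result["text"])
# "speech.wav" is a placeholder path; extra keyword arguments such as
# `language` are forwarded to DecodingOptions.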
|
214
whisper/whisper/whisper.py
Normal file
@ -0,0 +1,214 @@
import base64
import gzip
from dataclasses import dataclass
import math
from typing import Dict, Iterable, Optional

import mlx.core as mx
import mlx.nn as nn

import numpy as np

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = mx.exp(-log_timescale_increment * mx.arange(channels // 2))
    scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :]
    return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
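

# Illustrative note (added; not in the original source): sinusoids(1500, 384)
# returns a (1500, 384) mx.array whose first 192 columns are sines and last
# 192 are cosines over geometrically spaced timescales; it serves as the
# fixed, non-learned positional embedding of the audio encoder.

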
class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def __call__(
        self,
        x,
        xa=None,
        mask=None,
        kv_cache=None,
    ):
        q = self.query(x)

        if xa is None:
            k = self.key(x)
            v = self.value(x)
            if kv_cache is not None:
                k = mx.concatenate([kv_cache[0], k], axis=1)
                v = mx.concatenate([kv_cache[1], v], axis=1)
        elif kv_cache is None:
            k = self.key(xa)
            v = self.value(xa)
        else:
            k, v = kv_cache

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv), (k, v)

    def qkv_attention(self, q, k, v, mask=None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.reshape(*q.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3) * scale
        k = k.reshape(*k.shape[:2], self.n_head, -1).transpose(0, 2, 3, 1) * scale
        v = v.reshape(*v.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        qk = qk.astype(mx.float32)
        w = mx.softmax(qk, axis=-1).astype(q.dtype)
        out = (w @ v).transpose(0, 2, 1, 3)
        out = out.reshape(n_batch, n_ctx, n_state)
        return out
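

# Note (added for clarity): unlike the hook-based cache in torch_whisper.py,
# this module threads the cache explicitly. For self-attention (xa is None)
# new keys/values are concatenated onto the cached ones; for cross-attention
# the (k, v) computed from the audio features on the first call are reused.
# __call__ returns the updated (k, v) so the caller can pass it back in on
# the next decoding step.

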
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp1 = nn.Linear(n_state, n_mlp)
        self.mlp2 = nn.Linear(n_mlp, n_state)
        self.mlp_ln = nn.LayerNorm(n_state)

    def __call__(self, x, xa=None, mask=None, kv_cache=None):
        kv, cross_kv = kv_cache if kv_cache else (None, None)
        y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
        x += y
        if self.cross_attn:
            y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
            x += y
        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
        return x, (kv, cross_kv)


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self._positional_embedding = sinusoids(n_ctx, n_state)

        self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.ln_post = nn.LayerNorm(n_state)

    def __call__(self, x):
        x = nn.gelu(self.conv1(x))
        x = nn.gelu(self.conv2(x))
        assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
        x = x + self._positional_embedding

        for block in self.blocks:
            x, _ = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = mx.zeros((n_ctx, n_state))

        self.blocks = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ]
        self.ln = nn.LayerNorm(n_state)
        self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx)

    def __call__(self, x, xa, kv_cache=None):
        """
        x : mx.array, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : mx.array, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = kv_cache[0][0][0].shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )

        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)
        for e, block in enumerate(self.blocks):
            x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])

        x = self.ln(x)
        return x @ self.token_embedding.weight.T, kv_cache
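

# Note (added for clarity): logits come from multiplying with the transposed
# token embedding (weight tying), and `offset` places new tokens after those
# already stored in the per-block kv cache, so incremental decoding only has
# to feed the most recent token(s).

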
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel):
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        return self.decoder(tokens, audio_features)[0]

    def __call__(self, mel, tokens):
        return self.decoder(tokens, self.encoder(mel))[0]

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865

    detect_language = detect_language_function
    decode = decode_function
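
# Rough end-to-end sketch (assumed; `load_model` comes from load_models.py in
# this package, and `mel`/`tokens` are placeholders):
#   model = load_model("tiny")
#   audio_features = model.embed_audio(mel)        # (1, n_audio_ctx, n_audio_state)
#   logits = model.logits(tokens, audio_features)  # (1, n_tokens, n_vocab)
# Full decoding goes through `model.decode(mel, DecodingOptions(...))`, which
# is bound above from whisper.decoding.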