Support Hugging Face models (#215)

* support hf direct models
Awni Hannun 2024-01-03 15:13:26 -08:00 committed by GitHub
parent 1d09c4fecd
commit a5d6d0436c
16 changed files with 654 additions and 27 deletions


@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 22.10.0
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort


@ -32,7 +32,6 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
def main(args):
# Data loading
x, y, adj = load_data(args)
train_mask, val_mask, test_mask = train_val_test_mask()
@ -55,7 +54,6 @@ def main(args):
# Training loop
for epoch in range(args.epochs):
# Loss
(loss, y_hat), grads = loss_and_grad_fn(
gcn, x, adj, y, train_mask, args.weight_decay
@ -96,7 +94,6 @@ def main(args):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
parser.add_argument("--edges_path", type=str, default="cora/cora.cites")


@ -44,7 +44,9 @@ def convert(args):
config = model.config.to_dict()
state_dict = model.state_dict()
tokenizer = AutoTokenizer.from_pretrained(str(hf_path), trust_remote_code=True, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(
str(hf_path), trust_remote_code=True, use_fast=False
)
# things to change
# 1. there's no "model." in the weight names
@ -84,7 +86,9 @@ def convert(args):
weights = {k: v.numpy() for k, v in state_dict.items()}
config["rope_scaling_factor"] = config["rope_scaling"]["factor"] if config["rope_scaling"] is not None else 1.0
config["rope_scaling_factor"] = (
config["rope_scaling"]["factor"] if config["rope_scaling"] is not None else 1.0
)
keep_keys = set(
[
"vocab_size",
@ -96,7 +100,7 @@ def convert(args):
"rms_norm_eps",
"intermediate_size",
"rope_scaling_factor",
"rope_theta"
"rope_theta",
]
)
for k in list(config.keys()):


@ -285,7 +285,11 @@ if __name__ == "__main__":
model, tokenizer = load_model(args.model_path)
prompt = tokenizer(args.prompt, return_tensors="np", return_attention_mask=False,)[
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]

llms/hf_llm/README.md (new file, 75 lines)

@ -0,0 +1,75 @@
## Generate Text with MLX and :hugs: Hugging Face
This is an example of large language model text generation that can pull models from the Hugging Face Hub.
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
### Run
```
python generate.py --model <model_path> --prompt "hello"
```
For example:
```
python generate.py --model mistralai/Mistral-7B-v0.1 --prompt "hello"
```
will download the Mistral 7B model and generate text using the given prompt.
The `<model_path>` should be either a path to a local directory or a Hugging
Face repo with weights stored in `safetensors` format. If you use a repo from
the Hugging Face Hub, then the model will be downloaded and cached the first
time you run it. See the [Models](#models) section for a full list of supported models.
Run `python generate.py --help` to see all the options.
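For instance, a sketch of a longer, sampled generation (the prompt, temperature, and token count below are just illustrative values for options shown by `--help`):
```
python generate.py \
--model mistralai/Mistral-7B-v0.1 \
--prompt "Write a short story about a robot" \
--max-tokens 256 \
--temp 0.7
```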
### Models
The example supports Hugging Face format Mistral and Llama-style models. If the
model you want to run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
Here are a few examples of Hugging Face models which work with this example:
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending)
and
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending)
style models should work out of the box.
### Convert new models
You can convert (change the data type or quantize) models using the
`convert.py` script. This script takes a Hugging Face repo as input and outputs
a model directory (which you can optionally also upload to Hugging Face).
For example, to make a 4-bit quantized model, run:
```
python convert.py --hf-path <hf_repo> -q
```
For more options run:
```
python convert.py --help
```
You can upload new models to the [Hugging Face MLX
Community](https://huggingface.co/mlx-community) by specifying `--upload-name`
to `convert.py`.
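For instance, to quantize a model and upload the result in one step (the `--upload-name` value below is only a hypothetical repository name):
```
python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q --upload-name Mistral-7B-v0.1-4bit-mlx
```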

llms/hf_llm/convert.py (new file, 174 lines)

@ -0,0 +1,174 @@
# Copyright © 2023 Apple Inc.
import argparse
import copy
import glob
import json
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import transformers
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from models import Model, ModelArgs
def fetch_from_hub(hf_path: str):
model_path = snapshot_download(
repo_id=hf_path,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
weight_files = glob.glob(f"{model_path}/*.safetensors")
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
config = transformers.AutoConfig.from_pretrained(hf_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
hf_path,
)
return weights, config.to_dict(), tokenizer
def quantize(weights, config, args):
quantized_config = copy.deepcopy(config)
# Load the model:
model = Model(ModelArgs.from_dict(config))
model.load_weights(list(weights.items()))
# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
# Update the config:
quantized_config["quantization"] = {
"group_size": args.q_group_size,
"bits": args.q_bits,
}
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
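# Split the flat weight dict into shards so each saved safetensors file stays under the given size limit.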
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
estimated_size = v.size * v.dtype.size
if shard_size + estimated_size > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += estimated_size
shards.append(shard)
return shards
def upload_to_hub(path: str, name: str, hf_path: str):
import os
from huggingface_hub import HfApi, ModelCard, logging
repo_id = f"mlx-community/{name}"
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.text = f"""
# {name}
This model was converted to MLX format from [`{hf_path}`](https://huggingface.co/{hf_path}).
Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
## Use with mlx
```bash
pip install mlx
git clone https://github.com/ml-explore/mlx-examples.git
cd mlx-examples/llms/hf_llm
python generate.py --model {repo_id} --prompt "My name is"
```
"""
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=repo_id, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=repo_id,
repo_type="model",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument(
"--hf-path",
type=str,
help="Path to the Hugging Face model.",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"-q",
"--quantize",
help="Generate a quantized model.",
action="store_true",
)
parser.add_argument(
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
)
parser.add_argument(
"--dtype",
help="Type to save the parameters, ignored if -q is given.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-name",
help="The name of model to upload to Hugging Face MLX Community",
type=str,
default=None,
)
args = parser.parse_args()
print("[INFO] Loading")
weights, config, tokenizer = fetch_from_hub(args.hf_path)
if args.quantize:
print("[INFO] Quantizing")
weights, config = quantize(weights, config, args)
if not args.quantize:
dtype = getattr(mx, args.dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
mlx_path = Path(args.mlx_path)
mlx_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
for i, shard in enumerate(shards):
mx.save_safetensors(str(mlx_path / f"weights.{i:02d}.safetensors"), shard)
tokenizer.save_pretrained(mlx_path)
with open(mlx_path / "config.json", "w") as fid:
json.dump(config, fid, indent=4)
if args.upload_name is not None:
upload_to_hub(mlx_path, args.upload_name, args.hf_path)

llms/hf_llm/generate.py (new file, 86 lines)

@ -0,0 +1,86 @@
# Copyright © 2023 Apple Inc.
import argparse
import time
import mlx.core as mx
import models
import transformers
def generate(
model: models.Model,
tokenizer: transformers.AutoTokenizer,
prompt: str,
max_tokens: int,
temp: float = 0.0,
):
prompt = tokenizer(
prompt,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
][0]
prompt = mx.array(prompt)
tic = time.time()
tokens = []
skip = 0
for token, n in zip(
models.generate(prompt, model, temp),
range(max_tokens),
):
if token == tokenizer.eos_token_id:
break
if n == 0:
prompt_time = time.time() - tic
tic = time.time()
tokens.append(token.item())
# if (n + 1) % 10 == 0:
s = tokenizer.decode(tokens)
print(s[skip:], end="", flush=True)
skip = len(s)
print(tokenizer.decode(tokens)[skip:], flush=True)
gen_time = time.time() - tic
print("=" * 10)
prompt_tps = prompt.size / prompt_time
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="inference script")
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="In the beginning the Universe was created.",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = models.load(args.model)
generate(model, tokenizer, args.prompt, args.max_tokens, args.temp)

llms/hf_llm/models.py (new file, 255 lines)

@ -0,0 +1,255 @@
# Copyright © 2023 Apple Inc.
import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: str = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim, traditional=args.rope_traditional, base=args.rope_theta
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
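# Grouped-query attention: there are fewer key/value heads than query heads, so tile the key/value heads (repeats = n_heads // n_kv_heads) to match.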
def repeat(a):
a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
return a.reshape([B, self.n_heads, L, -1])
if self.repeats > 1:
keys, values = map(repeat, (keys, values))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
def load(path_or_hf_repo: str):
# If the path exists, load the model from it;
# otherwise download it from the Hugging Face Hub and cache it.
model_path = Path(path_or_hf_repo)
if not model_path.exists():
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
allow_patterns=["*.json", "*.safetensors", "tokenizer.model"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
quantization = config.get("quantization", None)
model_args = ModelArgs.from_dict(config)
weight_files = glob.glob(str(model_path / "*.safetensors"))
if len(weight_files) == 0:
raise FileNotFoundError("No safetensors found in {}".format(model_path))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf).items())
model = Model(model_args)
if quantization is not None:
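# Replace the Linear layers with QuantizedLinear before loading the quantized weights.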
nn.QuantizedLinear.quantize_module(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model.parameters())
tokenizer = AutoTokenizer.from_pretrained(
model_path,
)
return model, tokenizer
def generate(prompt: mx.array, model: Model, temp: float = 0.0):
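# Generate one token at a time: greedy when temp == 0, temperature sampling otherwise, reusing the key/value cache across steps.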
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
y = prompt
cache = None
while True:
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
y = sample(logits)
yield y

llms/hf_llm/requirements.txt (new file, 3 lines)

@ -0,0 +1,3 @@
mlx>=0.0.7
numpy
transformers


@ -1,9 +1,9 @@
# Copyright © 2023 Apple Inc.
import argparse
import glob
import json
import time
import glob
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple


@ -27,7 +27,11 @@ class Tokenizer:
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(s, return_tensors="np", return_attention_mask=False,)[
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)[
"input_ids"
].squeeze(0)
)


@ -79,9 +79,13 @@ class StableDiffusion:
return x_t_prev
def _denoising_loop(self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5):
def _denoising_loop(
self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5
):
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype):
for t, t_prev in self.sampler.timesteps(
num_steps, start_time=T, dtype=self.dtype
):
x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
yield x_t
@ -100,7 +104,9 @@ class StableDiffusion:
mx.random.seed(seed)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
conditioning = self._get_text_conditioning(
text, n_images, cfg_weight, negative_text
)
# Create the latent variables
x_T = self.sampler.sample_prior(
@ -108,7 +114,9 @@ class StableDiffusion:
)
# Perform the denoising loop
yield from self._denoising_loop(x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight)
yield from self._denoising_loop(
x_T, self.sampler.max_time, conditioning, num_steps, cfg_weight
)
def generate_latents_from_image(
self,
@ -130,16 +138,20 @@ class StableDiffusion:
num_steps = int(num_steps * strength)
# Get the text conditioning
conditioning = self._get_text_conditioning(text, n_images, cfg_weight, negative_text)
conditioning = self._get_text_conditioning(
text, n_images, cfg_weight, negative_text
)
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop
yield from self._denoising_loop(x_T, start_step, conditioning, num_steps, cfg_weight)
yield from self._denoising_loop(
x_T, start_step, conditioning, num_steps, cfg_weight
)
def decode(self, x_t):
x = self.autoencoder.decode(x_t)


@ -381,7 +381,6 @@ class UNetModel(nn.Module):
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)


@ -86,8 +86,13 @@ if __name__ == "__main__":
for model_name in models:
model_path = f"{args.mlx_dir}/{model_name}"
if not os.path.exists(model_path):
print(f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion")
subprocess.run(f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}", shell=True)
print(
f"\nDidn't find the MLX-format {model_name} model in the folder {args.mlx_dir}. Lauching conversion"
)
subprocess.run(
f"python convert.py --torch-name-or-path {model_name} --mlx-path {model_path}",
shell=True,
)
print(f"\nModel: {model_name.upper()}")
tokens = mx.array(


@ -71,7 +71,9 @@ def _download(url: str, root: str) -> str:
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
@ -132,7 +134,9 @@ def load_torch_model(
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).is_file():
raise RuntimeError(f"Model {name_or_path} is neither found in {available_models()} nor as a local path")
raise RuntimeError(
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
)
with open(name_or_path, "rb") as fp:
checkpoint = torch.load(fp)
@ -259,7 +263,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
assert args.dtype in _VALID_DTYPES, f"dtype {args.dtype} not found in {_VALID_DTYPES}"
assert (
args.dtype in _VALID_DTYPES
), f"dtype {args.dtype} not found in {_VALID_DTYPES}"
dtype = getattr(mx, args.dtype)
print("[INFO] Loading")


@ -10,6 +10,7 @@ from pathlib import Path
import mlx.core as mx
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from mlx.utils import tree_flatten
import whisper
@ -17,8 +18,6 @@ import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
from convert import load_torch_model, quantize, torch_to_mlx
MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
@ -189,7 +188,9 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
result = whisper.transcribe(TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False)
result = whisper.transcribe(
TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(
result["text"],
(
@ -208,7 +209,9 @@ class TestWhisper(unittest.TestCase):
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
result = whisper.transcribe(audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False)
result = whisper.transcribe(
audio_file, model_path=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(len(result["text"]), 10920)
self.assertEqual(result["language"], "en")
self.assertEqual(len(result["segments"]), 77)