Merge branch 'main' into feat/batch_generate

This commit is contained in:
L Lllvvuu 2024-10-09 15:03:30 -04:00
commit 8fb82fee43
No known key found for this signature in database
GPG Key ID: CFAD5A25056DDD0F
45 changed files with 2695 additions and 298 deletions

View File

@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.

View File

@ -27,6 +27,7 @@ Some more useful examples are listed below.
### Audio Models
- Speech recognition with [OpenAI's Whisper](whisper).
- Audio compression and generation with [Meta's EnCodec](encodec).
### Multimodal models

83
encodec/README.md Normal file
View File

@ -0,0 +1,83 @@
# EnCodec
An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.
### Setup
Install the requirements:
```
pip install -r requirements.txt
```
Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.
Install [FFmpeg](https://ffmpeg.org/):
```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```
Install SciPy:
```
pip install scipy
```
### Example
An example using the model:
```python
import mlx.core as mx
from utils import load, load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)
# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
return model.encode(feats, mask, bandwidth=3)
# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
return model.decode(codes, scales, mask)
codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)
# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]
# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```
The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.
### Optional
To convert models, use the `convert.py` script. To see the options, run:
```bash
python convert.py -h
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
[code](https://github.com/facebookresearch/encodec) for more details.

View File

@ -0,0 +1,30 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
from utils import load
# Benchmark a full encode/decode round trip of the 48 kHz EnCodec model in MLX.
NUM_WARMUP = 5
NUM_ITERS = 10

model, processor = load("mlx-community/encodec-48khz-float32")

# Random stereo input with 288000 samples per channel.
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
# Materialize the weights and inputs up front so they are not part of the timing.
mx.eval(model, feats, mask)


@mx.compile
def fun():
    """Encode the features at bandwidth 3 and decode them back to audio."""
    codes, scales = model.encode(feats, mask, bandwidth=3)
    reconstructed = model.decode(codes, scales, mask)
    return reconstructed


# Warm-up runs (the first calls also pay for compilation).
for _ in range(NUM_WARMUP):
    mx.eval(fun())

# perf_counter is monotonic and high resolution, unlike time.time(), so the
# measurement cannot be skewed by wall-clock adjustments.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    mx.eval(fun())
toc = time.perf_counter()

ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")

View File

@ -0,0 +1,34 @@
# Copyright © 2024 Apple Inc.
import time
import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel
# Benchmark a full encode/decode round trip of the reference PyTorch EnCodec
# model on the "mps" device.
NUM_WARMUP = 5
NUM_ITERS = 10

processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

# Random stereo input with 288000 samples per channel.
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    """Run one encode + decode pass and wait for the device to finish."""
    pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
    # The decoded audio itself is not needed; only the work is being timed.
    pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
    )
    torch.mps.synchronize()


# Warm-up runs.
for _ in range(NUM_WARMUP):
    fun()

# perf_counter is monotonic and high resolution, unlike time.time(), so the
# measurement cannot be skewed by wall-clock adjustments.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    fun()
toc = time.perf_counter()

ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")

213
encodec/convert.py Normal file
View File

@ -0,0 +1,213 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
import encodec
def fetch_from_hub(hf_repo: str) -> Path:
    """Download the JSON and safetensors files of ``hf_repo``.

    Args:
        hf_repo (str): Hugging Face repository id.

    Returns:
        Path: Local directory returned by ``snapshot_download``.
    """
    local_dir = snapshot_download(
        repo_id=hf_repo,
        allow_patterns=["*.json", "*.safetensors"],
    )
    return Path(local_dir)
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Write a model card next to the converted model and upload the folder to
    the Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    # The card text is dedented, so it can be indented with the code here.
    card_text = dedent(
        f"""
        ---
        language: en
        license: other
        library: mlx
        tags:
        - mlx
        ---
        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path}).
        This model is intended to be used with the [EnCodec MLX
        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
        """
    )
    readme_path = os.path.join(path, "README.md")
    ModelCard(card_text).save(readme_path)

    logging.set_verbosity_info()

    hub = HfApi()
    hub.create_repo(repo_id=upload_repo, exist_ok=True)
    hub.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Save ``weights`` as a single safetensors file plus an index JSON.

    Args:
        save_path (Union[str, Path]): Output directory (created if missing).
        weights (Dict[str, Any]): Mapping from parameter name to array.
    """
    weights_file = "model.safetensors"

    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    num_bytes = sum(array.nbytes for array in weights.values())

    mx.save_safetensors(
        str(out_dir / weights_file), weights, metadata={"format": "mlx"}
    )

    # Every weight lives in the one file; the map is written in sorted order.
    index = {
        "metadata": {"total_size": num_bytes},
        "weight_map": {name: weights_file for name in sorted(weights)},
    }
    with open(out_dir / "model.safetensors.index.json", "w") as fid:
        json.dump(index, fid, indent=4)
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The ``_name_or_path`` key is dropped and the remaining top-level keys are
    sorted before saving for better readability. The ``config`` passed in is
    left unmodified.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Work on a filtered copy rather than popping from the caller's dict, so
    # saving has no side effect on the object the caller still holds.
    cleaned = {key: value for key, value in config.items() if key != "_name_or_path"}

    # Sort the (top-level) keys for better readability.
    cleaned = dict(sorted(cleaned.items()))

    with open(config_path, "w") as fid:
        json.dump(cleaned, fid, indent=4)
def convert(
    upload: bool,
    model: str,
    dtype: str = None,
):
    """Convert a ``facebook/encodec_{model}`` checkpoint to MLX format.

    The converted weights and config are written to ``mlx_models``.

    Args:
        upload (bool): If true, upload the result to
            ``mlx-community/encodec-{model}-{dtype}``.
        model (str): Model variant suffix used to build the repo names.
        dtype (str): Name of an ``mx`` dtype to cast the weights to. If
            ``None`` the weights are not cast.
    """
    hf_repo = f"facebook/encodec_{model}"
    mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
    hub_path = fetch_from_hub(hf_repo)
    out_dir = Path("mlx_models")

    hf_weights = mx.load(str(hub_path / "model.safetensors"))
    with open(hub_path / "config.json", "r") as fid:
        config = SimpleNamespace(**json.load(fid))

    mlx_model = encodec.EncodecModel(config)

    skipped = ["weight_g", "embed_avg", "cluster_size", "inited"]
    converted = {}
    for name, value in hf_weights.items():
        prefix, param = name.rsplit(".", 1)
        if param == "weight_v":
            # Fold the weight-norm pair (weight_g, weight_v) into one weight.
            gain = hf_weights[prefix + ".weight_g"]
            value = gain * (
                value / mx.linalg.norm(value, axis=(1, 2), keepdims=True)
            )
            name = prefix + ".weight"
        elif param in skipped:
            continue
        elif "lstm" in prefix:
            kind, direction, layer_tag = param.split("_")
            if kind == "weight":
                target = "Wx" if direction == "ih" else "Wh"
            elif kind == "bias" and direction == "ih":
                # The input-hidden bias is merged into the hidden-hidden one.
                continue
            else:
                value = value + hf_weights[name.replace("_hh_", "_ih_")]
                target = "bias"
            name = prefix + "." + layer_tag[1:] + "." + target

        if "conv.weight" in name:
            # Possibly a transposed conv which has a different order
            transposed = False
            if "decoder" in name:
                layer = mlx_model.decoder.layers[int(name.split(".")[2])]
                transposed = "conv" in layer and isinstance(
                    layer.conv, nn.ConvTranspose1d
                )
            value = mx.moveaxis(value, 0 if transposed else 1, 2)

        converted[name] = value

    # Loading into the model checks the converted names against the model.
    mlx_model.load_weights(list(converted.items()))

    if dtype is not None:
        target_dtype = getattr(mx, dtype)
        converted = {k: v.astype(target_dtype) for k, v in converted.items()}

    save_weights(out_dir, converted)
    save_config(vars(config), config_path=out_dir / "config.json")

    if upload:
        upload_to_hub(out_dir, mlx_repo, hf_repo)
if __name__ == "__main__":
    # Command line entry point: parse the options and run the conversion.
    parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
    parser.add_argument(
        "--model",
        type=str,
        default="48khz",
        # Previously an empty string, which left the option undocumented in -h.
        help="EnCodec model variant to convert.",
        choices=["24khz", "32khz", "48khz"],
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload the weights to Hugging Face.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        help="Data type to convert the model to.",
        default="float32",
        choices=["float32", "bfloat16", "float16"],
    )
    args = parser.parse_args()
    convert(upload=args.upload, model=args.model, dtype=args.dtype)

671
encodec/encodec.py Normal file
View File

@ -0,0 +1,671 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
# Custom Metal kernel computing the element-wise part of one LSTM time step.
#
# Inputs: ``x``, ``h_in``, ``cell`` plus the scalars ``hidden_size``,
# ``time_step`` and ``num_time_steps``. Outputs: ``hidden_state`` and
# ``cell_state``.
#
# The header defines a sigmoid that always exponentiates a non-positive value
# (``-abs(x)``) and mirrors the result for negative ``x``.
#
# The body reads four gate pre-activations ``h_in[...] + x[...]`` spaced
# ``hidden_size`` apart (i, f, g, o in that order), with ``x`` offset by
# ``b * num_time_steps * d + time_step * d`` where ``d = 4 * hidden_size``.
# It then writes ``cell_state = f * cell + i * g`` and
# ``hidden_state = o * tanh(cell_state)``.
#
# NOTE(review): ``elem`` is computed as ``b * 4 * hidden_size + y`` and is also
# used to index ``cell`` / ``cell_state`` / ``hidden_state`` -- confirm this
# indexing is what is intended for batch sizes greater than one.
_lstm_kernel = mx.fast.metal_kernel(
name="lstm",
input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
output_names=["hidden_state", "cell_state"],
header="""
template <typename T>
T sigmoid(T x) {
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
return (x < 0) ? 1 - y : y;
}
""",
source="""
uint b = thread_position_in_grid.x;
uint d = hidden_size * 4;
uint elem = b * d + thread_position_in_grid.y;
uint index = elem;
uint x_index = b * num_time_steps * d + time_step * d + index;
auto i = sigmoid(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto f = sigmoid(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto g = metal::precise::tanh(h_in[index] + x[x_index]);
index += hidden_size;
x_index += hidden_size;
auto o = sigmoid(h_in[index] + x[x_index]);
cell_state[elem] = f * cell[elem] + i * g;
hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
""",
)
def lstm_custom(x, h_in, cell, time_step):
    """Launch the ``lstm`` Metal kernel for a single time step.

    Args:
        x: Three-dimensional input; its second-to-last dimension is passed
            to the kernel as the number of time steps.
        h_in: Recurrent contribution added to ``x`` inside the kernel.
        cell: Current cell state; both outputs take its shape.
        time_step: Index of the time step to process.

    Returns:
        The kernel outputs ``hidden_state`` and ``cell_state``.
    """
    assert x.ndim == 3, "Input to LSTM must have 3 dimensions."

    state_shape = cell.shape
    hidden_size = state_shape[-1]
    num_time_steps = x.shape[-2]
    batch_size = x.shape[0]

    return _lstm_kernel(
        inputs=[x, h_in, cell, hidden_size, time_step, num_time_steps],
        output_shapes=[state_shape, state_shape],
        output_dtypes=[h_in.dtype, h_in.dtype],
        grid=(batch_size, h_in.size // 4, 1),
        threadgroup=(256, 1, 1),
    )
class LSTM(nn.Module):
    """LSTM layer whose per-step gate computation runs in ``lstm_custom``.

    Args:
        input_size (int): Size of the last dimension of the input.
        hidden_size (int): Size of the hidden and cell states.
        bias (bool): Whether to add a bias to the input projection.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # The four gates are stacked along the first axis of each parameter.
        self.Wx = mx.zeros((4 * hidden_size, input_size))
        self.Wh = mx.zeros((4 * hidden_size, hidden_size))
        self.bias = mx.zeros((4 * hidden_size,)) if bias else None

    def __call__(self, x, hidden=None, cell=None):
        """Run the LSTM over every time step of ``x``.

        Args:
            x: Input whose first dimension is the batch and whose
                second-to-last dimension is time.
            hidden: Optional initial hidden state. Defaults to zeros.
            cell: Optional initial cell state. Defaults to zeros.

        Returns:
            The hidden states of all time steps stacked along axis ``-2``.
        """
        # Project the whole sequence at once; only the recurrent part is
        # computed step by step.
        if self.bias is not None:
            x = mx.addmm(self.bias, x, self.Wx.T)
        else:
            x = x @ self.Wx.T

        all_hidden = []
        B = x.shape[0]
        # Compare against None explicitly: ``cell or ...`` would take the truth
        # value of an array whenever a cell state is actually passed in.
        if cell is None:
            cell = mx.zeros((B, self.hidden_size), x.dtype)
        for t in range(x.shape[-2]):
            if hidden is None:
                # No previous state: the recurrent contribution is zero.
                hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
            else:
                hidden = hidden @ self.Wh.T
            hidden, cell = lstm_custom(x, hidden, cell, t)
            all_hidden.append(hidden)

        return mx.stack(all_hidden, axis=-2)
class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and normalization."""

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride, dilation=dilation
        )
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)

        self.stride = stride

        # Effective kernel size with dilations.
        self.kernel_size = (kernel_size - 1) * dilation + 1

        # The padding has to cover the *effective* kernel. Using the raw kernel
        # size is only equivalent when dilation == 1; with a larger dilation it
        # under-pads and the output no longer keeps the expected length.
        self.padding_total = self.kernel_size - stride

    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: mx.array,
    ) -> mx.array:
        """Return the extra right padding needed so that the last convolution
        window is complete for the length of ``hidden_states`` (axis 1)."""
        length = hidden_states.shape[1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = int(math.ceil(n_frames)) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
        return ideal_length - length

    def _pad1d(
        self,
        hidden_states: mx.array,
        paddings: Tuple[int, int],
        mode: str = "zero",
        value: float = 0.0,
    ):
        """Pad ``hidden_states`` along axis 1 by ``paddings = (left, right)``.

        Any ``mode`` other than ``"reflect"`` pads with the constant ``value``.
        """
        if mode != "reflect":
            # mx.pad applies a bare (before, after) pair to *every* axis, so
            # spell out per-axis widths to pad only axis 1.
            pad_width = [(0, 0)] * hidden_states.ndim
            pad_width[1] = (paddings[0], paddings[1])
            return mx.pad(
                hidden_states, pad_width, mode="constant", constant_values=value
            )

        # Reflect padding: mirror the samples next to each edge, excluding the
        # edge sample itself.
        length = hidden_states.shape[1]
        prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
        suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
        return mx.concatenate([prefix, hidden_states, suffix], axis=1)

    def __call__(self, hidden_states):
        """Pad, convolve and (optionally) normalize ``hidden_states``."""
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(
                hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
            )
        else:
            # Asymmetric padding required for odd strides
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states,
                (padding_left, padding_right + extra_padding),
                mode=self.pad_mode,
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states
class EncodecConvTranspose1d(nn.Module):
    """ConvTranspose1d with asymmetric or causal padding and normalization."""

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type

        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        if config.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)

        # Number of output steps trimmed (left + right) after the convolution.
        self.padding_total = kernel_size - stride

    def __call__(self, hidden_states):
        """Upsample ``hidden_states`` and trim the padding along axis 1."""
        out = self.conv(hidden_states)
        if self.norm_type == "time_group_norm":
            out = self.norm(out)

        # Causal mode trims a configurable share on the right; otherwise the
        # trim is split with the smaller half on the right.
        trim_right = (
            math.ceil(self.padding_total * self.trim_right_ratio)
            if self.causal
            else self.padding_total // 2
        )
        trim_left = self.padding_total - trim_right
        stop = out.shape[1] - trim_right
        return out[:, trim_left:stop, :]
class EncodecLSTM(nn.Module):
    """Stack of ``config.num_lstm_layers`` LSTMs with a skip connection."""

    def __init__(self, config, dimension):
        super().__init__()
        layers = []
        for _ in range(config.num_lstm_layers):
            layers.append(LSTM(dimension, dimension))
        self.lstm = layers

    def __call__(self, hidden_states):
        """Apply the LSTM layers in order and add the input back."""
        residual = hidden_states
        for layer in self.lstm:
            hidden_states = layer(hidden_states)
        return hidden_states + residual
class EncodecResnetBlock(nn.Module):
    """
    Residual block from SEANet model as used by EnCodec.
    """

    def __init__(self, config, dim: int, dilations: List[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        num_convs = len(kernel_sizes)
        if num_convs != len(dilations):
            raise ValueError("Number of kernel sizes should match number of dilations")

        hidden = dim // config.compress
        # Channel sizes at the conv boundaries: ``dim`` at both ends and
        # ``hidden`` in between.
        widths = [dim] + [hidden] * (num_convs - 1) + [dim]

        layers = []
        for idx, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            layers.append(nn.ELU())
            layers.append(
                EncodecConv1d(
                    config,
                    widths[idx],
                    widths[idx + 1],
                    kernel_size,
                    dilation=dilation,
                )
            )
        self.block = layers

        use_shortcut = getattr(config, "use_conv_shortcut", True)
        self.shortcut = (
            EncodecConv1d(config, dim, dim, kernel_size=1)
            if use_shortcut
            else nn.Identity()
        )

    def __call__(self, hidden_states):
        """Return ``shortcut(input) + block(input)``."""
        out = hidden_states
        for layer in self.block:
            out = layer(out)
        return self.shortcut(hidden_states) + out
class EncodecEncoder(nn.Module):
    """SEANet encoder as used by EnCodec."""

    def __init__(self, config):
        super().__init__()
        num_filters = config.num_filters

        # Input convolution.
        layers = [
            EncodecConv1d(
                config, config.audio_channels, num_filters, config.kernel_size
            )
        ]

        # One stage per ratio (in reverse order): residual blocks, then a
        # strided convolution that doubles the channel count.
        scaling = 1
        for ratio in reversed(config.upsampling_ratios):
            current_scale = scaling * num_filters
            layers.extend(
                EncodecResnetBlock(
                    config, current_scale, [config.dilation_growth_rate**j, 1]
                )
                for j in range(config.num_residual_layers)
            )
            layers.append(nn.ELU())
            layers.append(
                EncodecConv1d(
                    config,
                    current_scale,
                    current_scale * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            )
            scaling *= 2

        # LSTM followed by the output projection to the hidden size.
        layers.append(EncodecLSTM(config, scaling * num_filters))
        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(
                config,
                scaling * num_filters,
                config.hidden_size,
                config.last_kernel_size,
            )
        )

        self.layers = layers

    def __call__(self, hidden_states):
        """Apply every encoder layer in sequence."""
        out = hidden_states
        for layer in self.layers:
            out = layer(out)
        return out
class EncodecDecoder(nn.Module):
    """SEANet decoder as used by EnCodec."""

    def __init__(self, config):
        super().__init__()
        num_filters = config.num_filters
        scaling = int(2 ** len(config.upsampling_ratios))

        # Input convolution followed by the LSTM.
        layers = [
            EncodecConv1d(
                config,
                config.hidden_size,
                scaling * num_filters,
                config.kernel_size,
            ),
            EncodecLSTM(config, scaling * num_filters),
        ]

        # One stage per ratio: a transposed convolution that halves the
        # channel count, then the residual blocks.
        for ratio in config.upsampling_ratios:
            current_scale = scaling * num_filters
            layers.append(nn.ELU())
            layers.append(
                EncodecConvTranspose1d(
                    config,
                    current_scale,
                    current_scale // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            )
            layers.extend(
                EncodecResnetBlock(
                    config, current_scale // 2, (config.dilation_growth_rate**j, 1)
                )
                for j in range(config.num_residual_layers)
            )
            scaling //= 2

        # Output projection to the audio channels.
        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(
                config,
                num_filters,
                config.audio_channels,
                config.last_kernel_size,
            )
        )

        self.layers = layers

    def __call__(self, hidden_states):
        """Apply every decoder layer in sequence."""
        out = hidden_states
        for layer in self.layers:
            out = layer(out)
        return out
class EncodecEuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance."""

    def __init__(self, config):
        super().__init__()
        self.embed = mx.zeros((config.codebook_size, config.codebook_dim))

    def quantize(self, hidden_states):
        """Return, for each row of ``hidden_states``, the index of the
        codebook entry with the largest negated squared distance."""
        codebook = self.embed.T
        state_norms = hidden_states.square().sum(axis=1, keepdims=True)
        code_norms = codebook.square().sum(axis=0, keepdims=True)
        # -(|h|^2 - 2 h.e + |e|^2): the closest entry has the largest value.
        neg_dist = -(state_norms - 2 * hidden_states @ codebook + code_norms)
        return neg_dist.argmax(axis=-1)

    def encode(self, hidden_states):
        """Quantize over the last axis, keeping the leading dimensions."""
        original_shape = hidden_states.shape
        flat = hidden_states.reshape((-1, original_shape[-1]))
        indices = self.quantize(flat)
        return indices.reshape(*original_shape[:-1])

    def decode(self, embed_ind):
        """Look up the codebook entries for ``embed_ind``."""
        return self.embed[embed_ind]
class EncodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        """Delegate to the codebook's ``encode``."""
        codebook = self.codebook
        return codebook.encode(hidden_states)

    def decode(self, embed_ind):
        """Delegate to the codebook's ``decode``."""
        codebook = self.codebook
        return codebook.decode(embed_ind)
class EncodecResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer."""

    def __init__(self, config):
        super().__init__()
        self.codebook_size = config.codebook_size

        hop_length = np.prod(config.upsampling_ratios)
        self.frame_rate = math.ceil(config.sampling_rate / hop_length)
        # Number of quantizers is derived from the last target bandwidth.
        self.num_quantizers = int(
            1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
        )

        quantizers = []
        for _ in range(self.num_quantizers):
            quantizers.append(EncodecVectorQuantization(config))
        self.layers = quantizers

    def get_num_quantizers_for_bandwidth(
        self, bandwidth: Optional[float] = None
    ) -> int:
        """Return num_quantizers based on specified target bandwidth."""
        bw_per_q = math.log2(self.codebook_size) * self.frame_rate
        # Without a positive bandwidth, fall back to all quantizers.
        if bandwidth is None or not bandwidth > 0.0:
            return self.num_quantizers
        return int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))

    def encode(
        self, embeddings: mx.array, bandwidth: Optional[float] = None
    ) -> mx.array:
        """
        Encode a given input array with the specified frame rate at the given
        bandwidth. The RVQ encode method sets the appropriate number of
        quantizers to use and returns indices for each quantizer.
        """
        active_layers = self.layers[: self.get_num_quantizers_for_bandwidth(bandwidth)]

        remaining = embeddings
        per_layer_indices = []
        for quantizer in active_layers:
            codes = quantizer.encode(remaining)
            # Each quantizer sees what the previous ones failed to capture.
            remaining = remaining - quantizer.decode(codes)
            per_layer_indices.append(codes)

        return mx.stack(per_layer_indices, axis=1)

    def decode(self, codes: mx.array) -> mx.array:
        """Decode the given codes to the quantized representation."""
        total = None
        per_layer = codes.split(codes.shape[1], axis=1)
        for idx, layer_codes in enumerate(per_layer):
            decoded = self.layers[idx].decode(layer_codes.squeeze(1))
            total = decoded if total is None else decoded + total
        return total
class EncodecModel(nn.Module):
    """EnCodec model: encoder, residual vector quantizer and decoder."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = EncodecEncoder(config)
        self.decoder = EncodecDecoder(config)
        self.quantizer = EncodecResidualVectorQuantizer(config)

    def _encode_frame(
        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encodes the given input using the underlying VQVAE.

        Returns the codes and, when ``config.normalize`` is set, the scale the
        input was divided by (otherwise ``None``).
        """
        length = input_values.shape[1]
        duration = length / self.config.sampling_rate

        if (
            self.config.chunk_length_s is not None
            and duration > 1e-5 + self.config.chunk_length_s
        ):
            raise RuntimeError(
                f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
            )

        scale = None
        if self.config.normalize:
            # if the padding is non zero
            input_values = input_values * padding_mask[..., None]
            # Scale by the RMS of the channel-averaged signal.
            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        return codes, scale

    def encode(
        self,
        input_values: mx.array,
        padding_mask: mx.array = None,
        bandwidth: Optional[float] = None,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (mx.array): The input audio waveform with shape
                ``(batch_size, sequence_length, channels)``.
            padding_mask (mx.array): Padding mask used to pad the
                ``input_values``. Defaults to all ones.
            bandwidth (float, optional): The target bandwidth. Must be one of
                ``config.target_bandwidths``. If ``None``, uses the first entry
                of ``config.target_bandwidths``. bandwidth is represented as a
                thousandth of what it is, e.g. 6kbps bandwidth is represented
                as bandwidth == 6.0

        Returns:
            A tuple ``(codes, scales)``. ``codes`` stacks the per-chunk codes
            along a new leading axis and ``scales`` is a list with one
            rescaling factor per chunk (``None`` entries unless
            ``config.normalize`` is set).

        Raises:
            ValueError: If the bandwidth is not supported, the number of
                channels is not 1 or 2, or the input length is not padded for
                chunked encoding.
        """
        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.config.target_bandwidths}."
            )

        _, input_length, channels = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(
                f"Number of audio channels must be 1 or 2, but got {channels}"
            )

        chunk_length = self.chunk_length
        if chunk_length is None:
            # No chunking: the whole input is a single frame.
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.chunk_stride

        if padding_mask is None:
            padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)

        encoded_frames = []
        scales = []

        # Overlap between consecutive chunks.
        step = chunk_length - stride
        if (input_length % stride) != step:
            raise ValueError(
                "The input length is not properly padded for batched chunked "
                "encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
            frame = input_values[:, offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = mx.stack(encoded_frames)

        return (encoded_frames, scales)

    @staticmethod
    def _linear_overlap_add(frames: List[mx.array], stride: int):
        """Overlap-add ``frames`` placed ``stride`` steps apart, weighting each
        frame with a triangular window and normalizing by the summed weights."""
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        dtype = frames[0].dtype
        N, frame_length, C = frames[0].shape
        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]

        # Triangular window that peaks in the middle of the frame.
        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()
        weight = weight[:, None]

        sum_weight = mx.zeros((total_size, 1), dtype=dtype)
        out = mx.zeros((N, total_size, C), dtype=dtype)
        offset = 0

        for frame in frames:
            frame_length = frame.shape[1]
            out[:, offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        return out / sum_weight

    def _decode_frame(
        self, codes: mx.array, scale: Optional[mx.array] = None
    ) -> mx.array:
        """Decode one frame of codes, undoing the scaling if ``scale`` is given."""
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale
        return outputs

    @property
    def channels(self):
        """Number of audio channels (``config.audio_channels``)."""
        return self.config.audio_channels

    @property
    def sampling_rate(self):
        """Sampling rate (``config.sampling_rate``)."""
        return self.config.sampling_rate

    @property
    def chunk_length(self):
        """Chunk length in samples, or ``None`` if chunking is disabled."""
        if self.config.chunk_length_s is None:
            return None
        else:
            return int(self.config.chunk_length_s * self.config.sampling_rate)

    @property
    def chunk_stride(self):
        """Step between chunks in samples, or ``None`` if the chunk length or
        the overlap is not configured."""
        if self.config.chunk_length_s is None or self.config.overlap is None:
            return None
        else:
            return max(1, int((1.0 - self.config.overlap) * self.chunk_length))

    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: Union[mx.array, List[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end should be trimmed.

        Args:
            audio_codes (mx.array): Discrete codes as returned by ``encode``,
                with one entry per chunk along the first axis.
            audio_scales (mx.array): Scaling factor for each chunk.
            padding_mask (mx.array): Padding mask. If given and shorter than
                the decoded audio, the output is truncated to its length.

        Returns:
            The decoded audio waveform.

        Raises:
            ValueError: If chunking is disabled and ``audio_codes`` holds more
                than one frame.
        """
        chunk_length = self.chunk_length
        if chunk_length is None:
            # ``encode`` stacks frames along axis 0 (and the chunked branch
            # below iterates that axis), so the single frame is checked and
            # selected on axis 0. Axis 1 is the batch: using it here rejected
            # batches larger than one.
            if audio_codes.shape[0] != 1:
                raise ValueError(f"Expected one frame, got {len(audio_codes)}")
            audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
        else:
            decoded_frames = []

            for frame, scale in zip(audio_codes, audio_scales):
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)

            audio_values = self._linear_overlap_add(
                decoded_frames, self.chunk_stride or 1
            )

        # truncate based on padding mask
        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
            audio_values = audio_values[:, : padding_mask.shape[1]]
        return audio_values

37
encodec/example.py Normal file
View File

@ -0,0 +1,37 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load, load_audio, save_audio
# Target bandwidth for encoding. A lower bandwidth results in more
# compression but lower reconstruction quality.
BANDWIDTH = 3

# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)


# Encode at the given bandwidth.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=BANDWIDTH)


# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
num_samples = len(audio)
reconstructed = reconstructed[0, :num_samples]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)

3
encodec/requirements.txt Normal file
View File

@ -0,0 +1,3 @@
mlx>=0.18
numpy
huggingface_hub

66
encodec/test.py Normal file
View File

@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import numpy as np
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio, preprocess_audio
def compare_processors():
    """Check that ``preprocess_audio`` matches the Hugging Face processor on
    random stereo audio."""
    np.random.seed(0)
    num_samples = 95500
    audio = np.random.uniform(size=(2, num_samples)).astype(np.float32)

    hf_processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
    sampling_rate = hf_processor.sampling_rate

    pt_inputs = hf_processor(
        raw_audio=audio, sampling_rate=sampling_rate, return_tensors="pt"
    )
    # The MLX preprocessor takes time-major audio, hence the transpose.
    mx_values, mx_mask = preprocess_audio(
        mx.array(audio).T,
        sampling_rate,
        hf_processor.chunk_length,
        hf_processor.chunk_stride,
    )

    assert np.array_equal(pt_inputs["input_values"], mx_values.moveaxis(2, 1))
    assert np.array_equal(pt_inputs["padding_mask"], mx_mask)
def compare_models():
    """Check the MLX model against the Hugging Face model: codes must match
    exactly, scales and decoded audio within tolerance."""
    pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
    mx_model, _ = load("mlx-community/encodec-48khz-float32")

    np.random.seed(0)
    num_samples = 190560
    audio = np.random.uniform(size=(1, num_samples, 2)).astype(np.float32)
    mask = np.ones((1, num_samples), dtype=np.int32)

    pt_mask = torch.tensor(mask)[None]
    mx_mask = mx.array(mask)

    # Encode with both models.
    pt_encoded = pt_model.encode(torch.tensor(audio).moveaxis(2, 1), pt_mask)
    mx_encoded = mx_model.encode(mx.array(audio), mx_mask)
    mx_codes, mx_scales = mx_encoded[0], mx_encoded[1]

    pt_codes = pt_encoded.audio_codes.numpy()
    assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"

    for mx_scale, pt_scale in zip(mx_scales, pt_encoded.audio_scales):
        if mx_scale is None:
            continue
        pt_scale = pt_scale.numpy()
        assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)

    # Decode with both models.
    pt_audio = pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
    )
    pt_audio = pt_audio[0].squeeze().T.detach().numpy()

    mx_audio = mx_model.decode(*mx_encoded, mx.array(mask)).squeeze()

    assert np.allclose(
        pt_audio, mx_audio, atol=1e-4, rtol=1e-4
    ), "Decoding audio mismatch"
if __name__ == "__main__":
    # Run the preprocessing check first, then the model check.
    for check in (compare_processors, compare_models):
        check()

129
encodec/utils.py Normal file
View File

@ -0,0 +1,129 @@
# Copyright © 2024 Apple Inc.
import functools
import json
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import encodec
def save_audio(file: str, audio: mx.array, sampling_rate: int):
    """
    Save audio to a wave (.wav) file.

    Args:
        file (str): Output file path.
        audio (mx.array): Audio samples, scaled so that ``[-1, 1]`` maps to
            the 16-bit integer range.
        sampling_rate (int): Sampling rate written to the file.
    """
    from scipy.io.wavfile import write

    # Samples outside [-1, 1] would exceed the int16 range after scaling, so
    # clamp them instead of letting the cast overflow.
    audio = mx.clip(audio, -1.0, 1.0)
    audio = (audio * 32767).astype(mx.int16)
    write(file, sampling_rate, np.array(audio))
def load_audio(file: str, sampling_rate: int, channels: int):
    """
    Read audio into an mx.array, resampling if necessary.

    The file is decoded by the ``ffmpeg`` CLI, which must be in PATH.

    Args:
        file (str): The audio file to open.
        sampling_rate (int): The sample rate to resample the audio at if needed.
        channels (int): The number of audio channels.

    Returns:
        An mx.array containing the audio waveform in float32.

    Raises:
        RuntimeError: If the ffmpeg process exits with an error.
    """
    from subprocess import CalledProcessError, run

    # ffmpeg decodes to raw signed 16-bit little-endian PCM on stdout while
    # down-mixing and resampling as necessary.
    # fmt: off
    ffmpeg_cmd = ["ffmpeg", "-nostdin"]
    ffmpeg_cmd += ["-threads", "0"]
    ffmpeg_cmd += ["-i", file]
    ffmpeg_cmd += ["-f", "s16le"]
    ffmpeg_cmd += ["-ac", str(channels)]
    ffmpeg_cmd += ["-acodec", "pcm_s16le"]
    ffmpeg_cmd += ["-ar", str(sampling_rate)]
    ffmpeg_cmd += ["-"]
    # fmt: on

    try:
        completed = run(ffmpeg_cmd, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    samples = mx.array(np.frombuffer(completed.stdout, np.int16))
    # One row per sample, one column per channel, scaled to float32.
    return samples.reshape(-1, channels).astype(mx.float32) / 32767.0
def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Pad a sequence or batch of sequences to a common length for the EnCodec
    model and build the matching padding masks.

    Args:
        raw_audio (mx.array or List[mx.array]): The sequence or batch of
            sequences to be processed. One-dimensional inputs get a trailing
            axis of size one.
        sampling_rate (int): The sampling rate of the audio waveform. It is
            accepted for interface compatibility and not read here.
        chunk_length (int, optional): The model's chunk length. When given,
            the padded length is extended using ``chunk_stride``.
        chunk_stride (int, optional): The model's chunk stride.

    Returns:
        A tuple ``(inputs, masks)`` of stacked arrays. ``masks`` is boolean,
        true on original samples and false on padding.
    """
    batch = raw_audio if isinstance(raw_audio, list) else [raw_audio]
    batch = [a if a.ndim != 1 else a[..., None] for a in batch]

    target_length = max(a.shape[0] for a in batch)
    if chunk_length is not None:
        target_length += chunk_length - (target_length % chunk_stride)

    def pad_to_target(a):
        # Right-pad one sequence (and its mask) up to the common length.
        n = a.shape[0]
        valid = mx.ones((n,), dtype=mx.bool_)
        extra = target_length - n
        if extra > 0:
            valid = mx.pad(valid, (0, extra))
            a = mx.pad(a, ((0, extra), (0, 0)))
        return a, valid

    padded = [pad_to_target(a) for a in batch]
    inputs = [a for a, _ in padded]
    masks = [m for _, m in padded]
    return mx.stack(inputs), mx.stack(masks)
def load(path_or_repo):
    """
    Load the model and audio preprocessor.

    Args:
        path_or_repo: A local directory, or a Hugging Face repo id that is
            downloaded when no such local path exists. The directory must
            contain ``config.json`` and ``model.safetensors``.

    Returns:
        A tuple ``(model, processor)`` where ``processor`` is
        ``preprocess_audio`` with the sampling rate, chunk length and chunk
        stride already bound.
    """
    model_dir = Path(path_or_repo)
    if not model_dir.exists():
        # Not a local path: fetch the snapshot from the Hugging Face Hub.
        downloaded = snapshot_download(
            repo_id=path_or_repo,
            allow_patterns=["*.json", "*.safetensors", "*.model"],
        )
        model_dir = Path(downloaded)

    with open(model_dir / "config.json", "r") as f:
        raw_config = json.load(f)
    config = SimpleNamespace(**raw_config)

    model = encodec.EncodecModel(config)
    model.load_weights(str(model_dir / "model.safetensors"))

    processor = functools.partial(
        preprocess_audio,
        sampling_rate=config.sampling_rate,
        chunk_length=model.chunk_length,
        chunk_stride=model.chunk_stride,
    )
    # Force evaluation of the loaded parameters before returning.
    mx.eval(model)
    return model, processor

View File

@ -68,11 +68,10 @@ class LlavaModel(nn.Module):
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
if pixel_values is None:
return inputs_embeds
# Get the output hidden states from the vision model
*_, hidden_states = self.vision_tower(

View File

@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
The `mlx-lm` package also has:
- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
@ -29,7 +29,14 @@ from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
response = generate(model, tokenizer, prompt="hello", verbose=True)
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
@ -38,7 +45,9 @@ To see a description of all the arguments you can do:
>>> help(generate)
```
Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
Check out the [generation
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
@ -77,6 +86,11 @@ model, tokenizer = load(repo)
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
print()
@ -122,10 +136,44 @@ mlx_lm.convert \
--upload-repo mlx-community/my-4bit-mistral
```
### Long Prompts and Generations
MLX LM has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
To use the rotating key-value cache pass the argument `--max-kv-size n` where
`n` can be any integer. Smaller values like `512` will use very little RAM but
result in worse quality. Larger values like `4096` or higher will use more RAM
but have better quality.
Caching prompts can substantially speed up reusing the same long context with
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
```bash
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--kv-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
--kv-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
The cached prompt is treated as a prefix to the supplied prompt. Also notice
when using a cached prompt, the model to use is read from the cache and need
not be supplied explicitly.
### Supported Models
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
models. If the model you want to run is not supported, file an
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.

26
llms/a.py Normal file
View File

@ -0,0 +1,26 @@
# Scratch script: load a model, wrap a single benchmark prompt in the chat
# template, and generate a sampled response with verbose output.
import mlx_lm
# model, tokenizer = mlx_lm.load("mlx-community/SmolLM-1.7B-Instruct-fp16")
# NOTE(review): hard-coded path on the author's machine; this will not load
# elsewhere.
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")
# The draft model is loaded but not used below, because the `draft_model`
# argument to `generate` is commented out.
draft_model, draft_tokenizer = mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
# https://github.com/hemingkx/Spec-Bench/blob/main/data/spec_bench/question.jsonl
prompt = "Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences."
# Format the raw prompt as a single user turn and append the generation prompt.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=False,
    add_generation_prompt=True,
)
mlx_lm.generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    max_tokens=500,
    temp=1.0,
    min_p=0.1,
    repetition_penalty=1.2,
    # draft_model=draft_model,
)

41
llms/b.py Normal file
View File

@ -0,0 +1,41 @@
# Scratch script: run `mlx_lm.batch_generate` on a small batch of
# chat-formatted prompts.
import mlx_lm
import random
import string
# NOTE(review): hard-coded path on the author's machine; this will not load
# elsewhere.
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")
capital_letters = string.ascii_uppercase
# Every unordered pair of distinct capital letters.
distinct_pairs = [
    (a, b) for i, a in enumerate(capital_letters) for b in capital_letters[i + 1 :]
]
num_prompts = 16
prompt_template = "Think of a real word containing both the letters {l1} and {l2}. Then, say 3 sentences which use the word."
# NOTE(review): this randomly sampled prompt list is overwritten by the
# fixed list assigned immediately below, so it never reaches the model.
prompts = [
    prompt_template.format(l1=p[0], l2=p[1])
    for p in random.sample(distinct_pairs, num_prompts)
]
# Three fixed word-problem prompts; these are the ones actually generated.
prompts = [
    "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
    "James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?",
    "Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?"
]
# Format each prompt as a single user turn and append the generation prompt.
prompts = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for prompt in prompts
]
# The result is bound to `response` but not read afterwards; output relies
# on `verbose=True`.
response = mlx_lm.batch_generate(
    model,
    tokenizer,
    prompts=prompts,
    max_tokens=512,
    verbose=True,
    temp=1.0,
    min_p=0.1,
    repetition_penalty=1.2,
)

11
llms/c.py Normal file
View File

@ -0,0 +1,11 @@
# Scratch script: stream a completion for a single raw prompt and print each
# piece as soon as it arrives.
import mlx_lm

model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit")

pieces = mlx_lm.stream_generate(
    model,
    tokenizer,
    prompt="Meta Llama 3.1 is a ",
    max_tokens=100,
)
for piece in pieces:
    print(piece, end="", flush=True)

11
llms/d.py Normal file
View File

@ -0,0 +1,11 @@
# Scratch script: stream completions for two prompts at once and print the
# two outputs side by side, one row per step.
import mlx_lm

model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit")

steps = mlx_lm.stream_generate(
    model,
    tokenizer,
    prompt=["Meta Llama 3.1 is a ", "Google Gemma 2 is a "],
    max_tokens=20,
)
for step in steps:
    # Left column is padded to 30 characters so the second column lines up.
    left_column = step[0].ljust(30)
    print(left_column + step[1], flush=True)

21
llms/issue.txt Normal file
View File

@ -0,0 +1,21 @@
## Steps to reproduce
Run the following with and without `prefill_step_size=2` commented out:
```py
import mlx_lm
model, tokenizer = mlx_lm.load('/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit')
mlx_lm.generate(
model,
tokenizer,
prompt="69 + 420= ",
verbose=True,
max_tokens=10,
max_kv_size=5,
prefill_step_size=2,
)
```
The output is different. I notice that the RotatingKVCache has length 5 with prefill and length 7 without.

View File

@ -57,6 +57,9 @@ mlx_lm.lora \
--iters 600
```
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).
@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
By default, the adapter config and weights are saved in `adapters/`. You can
specify the output location with `--adapter-path`.
By default, the adapter config and learned weights are saved in `adapters/`.
You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
```
This will by default load the adapters from `adapters/`, and save the fused
model in the path `lora_fused_model/`. All of these are configurable.
model in the path `fused_model/`. All of these are configurable.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
@ -141,7 +144,7 @@ mlx_lm.fuse \
--export-gguf
```
This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.
## Data
@ -160,50 +163,86 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.
Currently, `*.jsonl` files support three data formats: `chat`,
`completions`, and `text`. Here are three examples of these formats:
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
data formats. Here are examples of these formats:
`chat`:
```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
`tools`:
```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```
<details>
<summary>View the expanded single data tool format</summary>
```jsonl
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello."
},
{ "role": "user", "content": "What is the weather in San Francisco?" },
{
"role": "assistant",
"content": "How can I assistant you today."
"tool_calls": [
{
"id": "call_id",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
}
}
]
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and country, eg. San Francisco, USA"
},
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
},
"required": ["location", "format"]
}
}
}
]
}
```
</details>
`completions`:
```jsonl
{
"prompt": "What is the capital of France?",
"completion": "Paris."
}
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
`text`:
```jsonl
{
"text": "This is an example for the model."
}
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example across multiple lines.
### Hugging Face Datasets
To use Hugging Face datasets, first install the `datasets` package:
@ -212,7 +251,13 @@ To use Hugging Face datasets, first install the `datasets` package:
pip install datasets
```
Specify the Hugging Face dataset arguments in a YAML config. For example:
If the Hugging Face dataset is already in a supported format, you can specify
it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiSQL data.
Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```
hf_dataset:
@ -231,11 +276,13 @@ hf_dataset:
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
In general, for the `chat` and `completions` formats, Hugging Face [chat
templates](https://huggingface.co/blog/chat-templates) are used. This applies
the model's chat template by default. If the model does not have a chat
template, then Hugging Face will use a default. For example, the final text in
the `chat` example above with Hugging Face's default template becomes:
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
[chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
are used. This applies the model's chat template by default. If the model does
not have a chat template, then Hugging Face will use a default. For example,
the final text in the `chat` example above with Hugging Face's default template
becomes:
```text
<|im_start|>system
@ -263,7 +310,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
setting this to `2` or `1` will reduce memory consumption. This may slow
things down a little, but will also reduce the memory use.
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
@ -285,7 +332,7 @@ mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--lora-layers 4 \
--num-layers 4 \
--data wikisql
```
@ -295,4 +342,5 @@ tokens-per-second, using the MLX Example
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)

View File

@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
relative to the directory the server was started in.
### List Models
Use the `v1/models` endpoint to list available models:
```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```
This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.

View File

@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.
from .utils import convert, generate, load, stream_generate
from .version import __version__
from ._version import __version__
from .utils import convert, generate, load, stream_generate, batch_generate

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.17.1"
__version__ = "0.18.2"

View File

@ -56,7 +56,7 @@ def setup_arg_parser():
parser.add_argument(
"--max-kv-size",
type=int,
default=1024,
default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
@ -139,11 +139,15 @@ def main():
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0]
cache_dict[f"{i}_values"] = c.state[1]
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
metadata["max_kv_size"] = str(args.max_kv_size)
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
if __name__ == "__main__":
main()

View File

@ -1,8 +1,12 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
@ -51,9 +55,6 @@ max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
# Use DoRA instead of LoRA.
use_dora: false
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.

View File

@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import apply_lora_layers, dequantize
from .tuner.utils import dequantize, load_adapters
from .utils import (
fetch_from_hub,
get_model_path,
@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
)
parser.add_argument(
"--save-path",
default="lora_fused_model",
default="fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
@ -77,16 +77,13 @@ def main() -> None:
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
model = apply_lora_layers(model, args.adapter_path)
model = load_adapters(model, args.adapter_path)
fused_linears = [
(n, m.fuse())
for n, m in model.named_modules()
if isinstance(
m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
)
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
]
if fused_linears:
model.update_modules(tree_unflatten(fused_linears))
if args.de_quantize:

View File

@ -2,6 +2,7 @@
import argparse
import json
import sys
import mlx.core as mx
@ -12,7 +13,10 @@ DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_KV_SIZE = 1024
def str2bool(string):
    """
    Interpret a command-line string as a boolean.

    Returns False only for "false" or "f" (case-insensitive); every other
    string, including the empty string, yields True.
    """
    lowered = string.lower()
    if lowered in ["false", "f"]:
        return False
    return True
def setup_arg_parser():
@ -40,7 +44,9 @@ def setup_arg_parser():
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
"--prompt",
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
@ -66,6 +72,12 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--verbose",
type=str2bool,
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
parser.add_argument(
"--colorize",
action="store_true",
@ -81,6 +93,7 @@ def setup_arg_parser():
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--kv-cache-file",
@ -171,14 +184,19 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif tokenizer.chat_template is None:
elif cache_history is not None:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
messages = [
{
"role": "user",
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
}
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
@ -195,29 +213,30 @@ def main():
else:
prompt = args.prompt
if args.colorize and not args.verbose:
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
# Determine the max kv size from the kv cache or passed arguments
max_kv_size = args.max_kv_size
if max_kv_size is None:
max_kv_size = (
int(metadata["max_kv_size"])
if cache_history is not None
else DEFAULT_MAX_KV_SIZE
)
if cache_history is not None:
max_kv_size = metadata["max_kv_size"]
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
generate(
response = generate(
model,
tokenizer,
prompt,
args.max_tokens,
verbose=True,
verbose=args.verbose,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=max_kv_size,
cache_history=cache_history,
)
if not args.verbose:
print(response)
if __name__ == "__main__":

View File

@ -67,7 +67,7 @@ class HfVocab:
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
@ -77,9 +77,7 @@ class HfVocab:
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
toktype = self.get_token_type(
self.specials[text], b"", self.special_ids
)
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
score = self.get_token_score(self.specials[text])
else:
toktype = TokenType.USER_DEFINED
@ -243,12 +241,15 @@ def prepare_metadata(config, vocab):
metadata["tokenizer.ggml.tokens"] = tokens
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
if vocab.tokenizer.bos_token_id is not None:
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
vocab.tokenizer.bos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.eos_token_id is not None:
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
vocab.tokenizer.eos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.unk_token_id is not None:
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
vocab.tokenizer.unk_token_id, dtype=mx.uint32
)

View File

@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
apply_lora_layers,
build_schedule,
linear_to_lora_layers,
load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
"lora_layers": 16,
"num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
"max_seq_length": 2048,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"use_dora": False,
}
@ -79,10 +79,20 @@ def build_parser():
parser.add_argument(
"--data",
type=str,
help="Directory with {train, valid, test}.jsonl files",
help=(
"Directory with {train, valid, test}.jsonl files or the name "
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
),
)
parser.add_argument(
"--lora-layers",
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
@ -107,12 +117,12 @@ def build_parser():
parser.add_argument(
"--resume-adapter-file",
type=str,
help="Load path to resume training with the given adapters.",
help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Save/load path for the adapters.",
help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
@ -148,9 +158,6 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
parser.add_argument(
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
)
return parser
@ -162,21 +169,31 @@ def train_model(
valid_set,
training_callback: TrainingCallback = None,
):
# Freeze all layers
model.freeze()
if args.fine_tune_type == "full":
    # Unfreeze only the last `num_layers` layers. `max(..., 0)` turns the
    # "-1 for all" value into a slice start of 0, which selects every
    # layer, while a positive count N selects `model.layers[-N:]`.
    for l in model.layers[-max(args.num_layers, 0) :]:
        l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
    # Convert linear layers to lora/dora layers and unfreeze in the process
    linear_to_lora_layers(
        model,
        args.num_layers,
        args.lora_parameters,
        use_dora=(args.fine_tune_type == "dora"),
    )
else:
    raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
# Convert linear layers to lora layers and unfreeze in the process
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
# Resume training the given adapters.
# Resume from weights if provided
if args.resume_adapter_file is not None:
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
@ -240,7 +257,7 @@ def run(args, training_callback: TrainingCallback = None):
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
apply_lora_layers(model, args.adapter_path)
load_adapters(model, args.adapter_path)
elif args.train:
print("Training")

231
llms/mlx_lm/models/mamba.py Normal file
View File

@ -0,0 +1,231 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    """
    Configuration for the Mamba model.

    After initialization, alternative attribute names (``d_model``,
    ``d_inner``, ``d_state``, ``n_layer``/``n_layers``, ``d_conv``, ``bias``,
    ``conv_bias``) are copied onto the canonical fields when the canonical
    attribute is missing, and a ``time_step_rank`` of ``"auto"`` is replaced
    by ``ceil(hidden_size / 16)``.
    """

    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    state_size: int
    num_hidden_layers: int
    conv_kernel: int
    use_bias: bool
    use_conv_bias: bool
    time_step_rank: int
    tie_word_embeddings: bool = True

    def __post_init__(self):
        # (canonical field, alternative name) pairs, applied in this order.
        fallbacks = (
            ("hidden_size", "d_model"),
            ("intermediate_size", "d_inner"),
            ("state_size", "d_state"),
            ("num_hidden_layers", "n_layer"),
            ("num_hidden_layers", "n_layers"),
            ("conv_kernel", "d_conv"),
            ("use_bias", "bias"),
            ("use_conv_bias", "conv_bias"),
        )
        for canonical, alternative in fallbacks:
            if not hasattr(self, canonical) and hasattr(self, alternative):
                setattr(self, canonical, getattr(self, alternative))

        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaCache:
    """
    Two-slot cache that can be indexed and assigned like a list.

    Both slots start as ``None``. ``MambaBlock.__call__`` keeps the
    convolution cache in slot 0 and the SSM state in slot 1.
    """

    def __init__(self):
        self.cache = [None] * 2

    def __setitem__(self, idx, value):
        slots = self.cache
        slots[idx] = value

    def __getitem__(self, idx):
        slots = self.cache
        return slots[idx]

    @property
    def state(self):
        """The underlying two-element list (not a copy)."""
        return self.cache
class DepthWiseConv1d(nn.Module):
    """
    Depthwise 1-D convolution that returns its trailing inputs as a cache.

    The weight has shape ``(channels, kernel_size, 1)`` and the convolution
    is run with ``groups`` equal to the number of channels.
    """

    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        # Stored for reference only; ``__call__`` always prepends either the
        # cache or ``kernel_size - 1`` zeros along axis 1.
        self.padding = padding
        self.weight = mx.random.normal((self.channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        """
        Apply the convolution.

        Args:
            x: A 3-D input; axis 1 is the one convolved over.
            cache: Optional array to prepend to ``x`` along axis 1 (the
                second value returned by a previous call). When ``None``,
                ``x`` is left-padded with ``K - 1`` zeros instead.

        Returns:
            A tuple ``(y, new_cache)`` where ``new_cache`` holds the last
            ``K - 1`` steps of the extended input.
        """
        # The unpack also rejects inputs that are not 3-D.
        B, L, C = x.shape
        groups, K, _ = self.weight.shape

        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=groups)

        if self.bias is not None:
            y = y + self.bias

        # Keep the trailing K - 1 steps. The start index is computed
        # explicitly because `-K + 1` is 0 when K == 1, which would select
        # the whole sequence instead of an empty cache.
        start = max(x.shape[1] - (K - 1), 0)
        return y, x[:, start:, :]
class MambaBlock(nn.Module):
    """
    Mamba mixer block.

    Each time step goes through an input projection, a depthwise
    convolution, a SiLU, the state-space update in ``ssm_step``, a SiLU
    gate and an output projection. Time steps are processed one at a time
    in a Python loop.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.ssm_state_size = args.state_size
        self.conv_kernel_size = args.conv_kernel
        self.intermediate_size = args.intermediate_size
        self.time_step_rank = int(args.time_step_rank)
        self.use_conv_bias = args.use_conv_bias

        # Produces both the convolution/SSM branch and the gate branch,
        # hence twice the intermediate size.
        self.in_proj = nn.Linear(
            self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
        )

        self.conv1d = DepthWiseConv1d(
            channels=self.intermediate_size,
            kernel_size=self.conv_kernel_size,
            bias=self.use_conv_bias,
            padding=self.conv_kernel_size - 1,
        )

        # Output is split in `ssm_step` into delta (time_step_rank), B and C
        # (ssm_state_size each).
        self.x_proj = nn.Linear(
            self.intermediate_size,
            self.time_step_rank + 2 * self.ssm_state_size,
            bias=False,
        )
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # A is initialized to the values 1..ssm_state_size repeated for every
        # intermediate channel, and stored as its logarithm.
        A = mx.repeat(
            mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
            repeats=self.intermediate_size,
            axis=0,
        )
        self.A_log = mx.log(A)
        self.D = mx.ones([self.intermediate_size])

        self.out_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=args.use_bias
        )

    def ssm_step(self, x, state=None):
        """
        Run one state-space update for a single time step.

        Args:
            x: Input for this step, as produced in ``__call__`` (assumed to
                be ``(batch, intermediate_size)`` -- TODO confirm).
            state: The state returned by the previous step, or ``None`` for
                the first step.

        Returns:
            A tuple ``(y, new_state)``.
        """
        # Stored as a log; exponentiating and negating keeps A negative.
        A = -mx.exp(self.A_log)
        D = self.D
        deltaBC = self.x_proj(x)
        # Split points: [0, rank) -> delta, [rank, rank + state) -> B,
        # remainder -> C.
        delta, B, C = mx.split(
            deltaBC,
            indices_or_sections=[
                self.time_step_rank,
                self.time_step_rank + self.ssm_state_size,
            ],
            axis=-1,
        )
        delta = nn.softplus(self.dt_proj(delta))
        # Input contribution: (delta * x) broadcast against B.
        new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
        if state is not None:
            # Carry over the previous state, scaled by exp(delta * A).
            new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
        # Contract the state axis with C, then add the D * x skip term.
        y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
        y = y + D * x
        return y, new_state

    def __call__(self, x, cache):
        """
        Apply the block to a sequence.

        Args:
            x: A 3-D input; axis 1 is iterated over as the time axis.
            cache: A two-slot indexable object (slot 0: convolution cache,
                slot 1: SSM state) that is updated in place, or ``None``.

        Returns:
            The per-step outputs stacked along axis 1.
        """
        B, T, D = x.shape
        if cache is None:
            # Local placeholder: state is carried across the steps of this
            # call but is not visible to the caller afterwards.
            cache = [None, None]

        outputs = []
        for t in range(T):
            xt = x[:, t, :]
            xz = self.in_proj(xt)
            # First half feeds the convolution/SSM branch, second half is the gate.
            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
            # The convolution expects a time axis, so add one and remove it again.
            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
            x_t = conv_out.squeeze(1)
            x_t = nn.silu(x_t)
            y_t, cache[1] = self.ssm_step(x_t, cache[1])
            z_t = nn.silu(z_t)
            output_t = y_t * z_t
            output_t = self.out_proj(output_t)
            outputs.append(output_t)
        output = mx.stack(outputs, axis=1)
        return output
class ResidualBlock(nn.Module):
    """RMS-normalize the input, apply a ``MambaBlock``, and add the input back."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = MambaBlock(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        normed = self.norm(x)
        mixed = self.mixer(normed, cache)
        return mixed + x
class Mamba(nn.Module):
    """Token embedding, a stack of ``ResidualBlock`` layers, and a final RMS norm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        hidden = self.embeddings(x)
        # Without a cache every layer receives None.
        layer_caches = [None] * len(self.layers) if cache is None else cache
        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, block_cache)
        return self.norm_f(hidden)
class Model(nn.Module):
    """Mamba language model: backbone plus an output projection to the vocabulary."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)
        # With tied embeddings the embedding matrix doubles as the output
        # projection, so no separate head is created.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        """Return logits for ``inputs`` (which must be two-dimensional)."""
        # Unpacking is kept so that a non-2D input still fails here.
        _batch, _seq_len = inputs.shape

        hidden = self.backbone(inputs, cache)
        if self.args.tie_word_embeddings:
            return self.backbone.embeddings.as_linear(hidden)
        return self.lm_head(hidden)

    def sanitize(self, weights):
        """Swap the last two axes of every 3-D ``conv1d.weight`` entry.

        ``weights`` is updated in place and also returned.
        """
        for name, value in weights.items():
            needs_swap = "conv1d.weight" in name and value.ndim == 3
            if needs_swap:
                weights[name] = value.moveaxis(2, 1)
        return weights

    def make_cache(self, batch_size: int = 1):
        """Create one fresh ``MambaCache`` per layer (``batch_size`` is unused)."""
        return [MambaCache() for _ in self.layers]

    @property
    def layers(self):
        """The backbone's list of residual blocks."""
        return self.backbone.layers

View File

@ -0,0 +1,227 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration fields for the Nemotron model."""

    model_type: str
    hidden_size: int
    hidden_act: str
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    partial_rotary_factor: float = 0.5
    rope_theta: float = 10000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        """Validate ``rope_scaling`` when one is provided.

        Raises:
            ValueError: If ``factor`` is missing, if neither ``type`` nor
                ``rope_type`` is given, or if the type is not ``"linear"``.
        """
        scaling = self.rope_scaling
        if not scaling:
            return

        if "factor" not in scaling:
            raise ValueError("rope_scaling must contain 'factor'")

        # Either spelling of the key is accepted.
        rope_type = scaling.get("type") or scaling.get("rope_type")
        if rope_type is None:
            raise ValueError(
                "rope_scaling must contain either 'type' or 'rope_type'"
            )

        if rope_type != "linear":
            raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@partial(mx.compile, shapeless=True)
def relu_squared(x):
    """Apply ReLU to ``x`` and square the result elementwise."""
    activated = nn.relu(x)
    return mx.square(activated)
class NemotronLayerNorm1P(nn.LayerNorm):
    """Layer norm that adds 1 to the stored weight before applying it."""

    def __call__(self, x):
        # Either parameter may be absent from the module; pass None for it.
        weight = None
        if "weight" in self:
            weight = self.weight + 1

        bias = None
        if "bias" in self:
            bias = self.bias

        return mx.fast.layer_norm(x, weight, bias, self.eps)
class Attention(nn.Module):
    """Multi-head attention with rotary embeddings on part of each head.

    Query heads and key/value heads are counted separately
    (``num_attention_heads`` vs ``num_key_value_heads``). RoPE is applied to
    the first ``int(partial_rotary_factor * head_dim)`` dimensions.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        # Fall back to an even split of the hidden size when no explicit
        # head dimension is configured.
        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
        self.partial_rotary_factor = args.partial_rotary_factor

        self.scale = head_dim**-0.5
        attention_bias = getattr(args, "attention_bias", False)

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = 1.0
        if args.rope_scaling:
            # ModelArgs.__post_init__ accepts the scaling kind under either
            # "type" or "rope_type", so look it up the same way here instead
            # of indexing "type" directly (which raised KeyError).
            rope_type = args.rope_scaling.get("type") or args.rope_scaling.get(
                "rope_type"
            )
            if rope_type == "linear":
                # float() also accepts an integer factor such as 2.
                rope_scale = 1 / float(args.rope_scaling["factor"])
        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        """Compute attention over ``x``.

        Args:
            x: Input of shape ``(B, L, hidden)``.
            mask: Optional attention mask forwarded to the attention kernel.
            cache: Optional key/value cache. When given, RoPE is offset by
                ``cache.offset`` and the new keys/values are merged into it.

        Returns:
            The output projection of the attention result, shape
            ``(B, L, hidden)``.
        """
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
class MLP(nn.Module):
    """Two-layer feed-forward network with a squared-ReLU activation."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim, inner_dim = args.hidden_size, args.intermediate_size
        use_bias = args.mlp_bias

        self.down_proj = nn.Linear(inner_dim, model_dim, bias=use_bias)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=use_bias)

    def __call__(self, x) -> mx.array:
        """Return ``down_proj(relu_squared(up_proj(x)))``."""
        expanded = self.up_proj(x)
        return self.down_proj(relu_squared(expanded))
class TransformerBlock(nn.Module):
    """Attention and MLP sub-layers, each normalized on input and wrapped in a residual."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
        self.post_attention_layernorm = NemotronLayerNorm1P(
            args.hidden_size, eps=args.norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        """Apply the attention sub-layer and then the MLP sub-layer to ``x``."""
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(h))
        return h + mlp_out
class NemotronModel(nn.Module):
    """Token embedding, a stack of transformer blocks, and a final layer norm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Embed ``inputs`` and run every block.

        Args:
            inputs: Token ids to embed.
            cache: One cache entry per layer, or ``None`` for no caching.

        Returns:
            The normalized output of the last block.
        """
        hidden = self.embed_tokens(inputs)

        # The mask is built from the caller's cache value, before the
        # per-layer placeholder list is substituted.
        mask = create_attention_mask(hidden, cache)

        layer_caches = [None] * len(self.layers) if cache is None else cache
        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, mask, cache=block_cache)

        return self.norm(hidden)
class Model(nn.Module):
    """Nemotron language model: ``NemotronModel`` plus a vocabulary projection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = NemotronModel(args)
        # With tied embeddings the embedding matrix is reused as the output
        # projection, so no separate head is created.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return logits for ``inputs``."""
        hidden = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(hidden)
        return self.lm_head(hidden)

    @property
    def layers(self):
        """The list of transformer blocks."""
        return self.model.layers

    @property
    def head_dim(self):
        """Configured head dimension, or ``hidden_size // num_attention_heads``."""
        args = self.args
        return args.head_dim or args.hidden_size // args.num_attention_heads

    @property
    def n_kv_heads(self):
        """Number of key/value heads from the configuration."""
        return self.args.num_key_value_heads

View File

@ -11,6 +11,7 @@ from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from .utils import generate_step, load
@ -596,6 +597,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
prompt = self.tokenizer.apply_chat_template(
body["messages"],
body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
@ -621,6 +623,46 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
def do_GET(self):
    """
    Respond to a GET request from a client.

    Only ``/v1/models`` is served; any other path gets a 404 response with
    the body ``Not Found``.
    """
    if self.path != "/v1/models":
        self._set_completion_headers(404)
        self.end_headers()
        self.wfile.write(b"Not Found")
        return

    self.handle_models_request()
def handle_models_request(self):
    """
    Handle a GET request for the /v1/models endpoint.

    Writes a JSON object of the form ``{"object": "list", "data": [...]}``
    with one entry per cached repository whose id contains ``"mlx"``.
    """
    self._set_completion_headers(200)
    self.end_headers()

    # Scan the cache directory and keep only the mlx repositories.
    models = []
    for repo in scan_cache_dir().repos:
        if "mlx" not in repo.repo_id:
            continue
        models.append(
            {
                "id": repo.repo_id,
                "object": "model",
                "created": self.created,
            }
        )

    payload = json.dumps({"object": "list", "data": models}).encode()
    self.wfile.write(payload)
    self.wfile.flush()
def run(
host: str,

View File

@ -36,7 +36,10 @@ class ChatDataset(Dataset):
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
@ -73,17 +76,14 @@ class CompletionsDataset(Dataset):
return text
def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
# Return empty dataset for non-existent paths
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
if "messages" in data[0]:
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in data[0] and "completion" in data[0]:
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif "text" in data[0]:
elif "text" in sample:
return Dataset(data)
else:
raise ValueError(
@ -92,8 +92,39 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
)
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
    """Load the train/valid/test splits from JSONL files in a directory.

    Args:
        data_path: Directory expected to contain ``train.jsonl``,
            ``valid.jsonl`` and ``test.jsonl``.
        tokenizer: Tokenizer forwarded to ``create_dataset``.

    Returns:
        A ``(train, valid, test)`` tuple. A split whose file is missing or
        holds no records is returned as an empty list.
    """

    def load_subset(path):
        # A missing file means the split was simply not provided.
        if not path.exists():
            return []
        with open(path, "r") as fid:
            # Skip blank lines (e.g. trailing newlines); json.loads would
            # otherwise fail on them.
            data = [json.loads(line) for line in fid if line.strip()]
        # create_dataset inspects data[0], so a file with no records must be
        # reported as an empty split rather than raising IndexError.
        if not data:
            return []
        return create_dataset(data, tokenizer)

    names = ("train", "valid", "test")
    train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
    return train, valid, test
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
    """Load the train/valid/test splits of a Hugging Face dataset.

    Args:
        data_id: Dataset identifier passed to ``datasets.load_dataset``.
        tokenizer: Tokenizer forwarded to ``create_dataset``.

    Returns:
        A ``(train, valid, test)`` tuple; a split the dataset does not
        provide is returned as an empty list.

    Raises:
        ValueError: If the dataset cannot be found.
    """
    from datasets import exceptions, load_dataset

    try:
        dataset = load_dataset(data_id)
        available = dataset.keys()
        splits = []
        for split_name in ("train", "valid", "test"):
            if split_name in available:
                splits.append(create_dataset(dataset[split_name], tokenizer))
            else:
                splits.append([])
        train, valid, test = splits
    except exceptions.DatasetNotFoundError:
        raise ValueError(f"Not found Hugging Face dataset: {data_id} .")

    return train, valid, test
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
hf_args = args.hf_dataset
@ -110,9 +141,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
**hf_args.get("config", {}),
)
if prompt_feature and completion_feature:
return CompletionsDataset(
ds, tokenizer, prompt_feature, completion_feature
)
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(train_ds, text_key=text_feature)
else:
@ -133,13 +162,20 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
else:
test = []
else:
names = ("train", "valid", "test")
data_path = Path(args.data)
return train, valid, test
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer)
train, valid, test = [
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
]
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."

View File

@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
dropout: float = 0.0,
scale: float = 20.0,
):
# TODO support quantized weights in DoRALinear
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
raise ValueError("DoRALinear does not yet support quantization.")
input_dims *= 32 // linear.bits
dora_lin = DoRALinear(
input_dims=input_dims,
output_dims=output_dims,
@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
weight = self._dequantized_weight()
# Use the same type as the linear weight if not quantized
# Use the same type as the linear weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
if bias:
fused_linear.bias = linear.bias
if self._is_quantized() and not de_quantize:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def set_linear(self, linear: nn.Linear):
def set_linear(self, linear):
"""
Set the self.linear layer and recompute self.m.
"""
self.linear = linear
self.m = mx.linalg.norm(self.linear.weight, axis=1)
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
def _dequantized_weight(self):
    """
    Return the weight of the wrapped linear layer, dequantized if the
    layer is quantized.
    """
    linear = self.linear
    if not self._is_quantized():
        return linear.weight

    return mx.dequantize(
        linear.weight,
        linear.scales,
        linear.biases,
        linear.group_size,
        linear.bits,
    )
def _is_quantized(self):
    """Return True when the wrapped layer is an ``nn.QuantizedLinear``."""
    wrapped = self.linear
    return isinstance(wrapped, nn.QuantizedLinear)
def __call__(self, x):
# Regular LoRA (without a bias)
y = x @ self.linear.weight.T
w = self._dequantized_weight()
y = x @ w.T
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
out = y + (self.scale * z).astype(x.dtype)
# Compute the norm of the adapted weights
adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
# Remove the norm and scale by the learned magnitude
out = (self.m / denom) * out
out = (self.m / denom).astype(x.dtype) * out
if "bias" in self.linear:
out = out + self.linear.bias

View File

@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
@ -93,9 +95,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] == tokenizer.eos_token_id:
print("[WARNING] Example already has an EOS token appended")
else:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
lengths = [len(x) for x in batch]
@ -287,24 +287,18 @@ def train(
# Save adapter weights
if it % args.steps_per_save == 0:
save_adapter(model, args.adapter_file)
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
save_adapter(model, checkpoint)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# save final adapter weights
save_adapter(model, args.adapter_file)
print(f"Saved final adapter weights to {args.adapter_file}.")
def save_adapter(
model: nn.Module,
adapter_file: Union[str, Path],
):
flattened_tree = tree_flatten(model.trainable_parameters())
mx.save_safetensors(str(adapter_file), dict(flattened_tree))
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")

View File

@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
def linear_to_lora_layers(
model: nn.Module,
num_lora_layers: int,
num_layers: int,
config: Dict,
use_dora: bool = False,
):
@ -45,23 +45,17 @@ def linear_to_lora_layers(
Args:
model (nn.Module): The neural network model.
num_lora_layers (int): The number of blocks to convert to lora layers
num_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
num_layers = len(model.layers)
if num_lora_layers < 0:
num_lora_layers = num_layers
if num_lora_layers > num_layers:
if num_layers > len(model.layers):
raise ValueError(
f"Requested {num_lora_layers} LoRA layers "
f"but the model only has {num_layers} layers."
f"Requested {num_layers} LoRA layers "
f"but the model only has {len(model.layers)} layers."
)
def to_lora(layer):
@ -93,6 +87,7 @@ def linear_to_lora_layers(
"llama",
"phi",
"mixtral",
"nemotron",
"stablelm",
"qwen2",
"qwen2_moe",
@ -139,10 +134,19 @@ def linear_to_lora_layers(
"self_attn.kv_b_proj",
]
)
elif model.model_type == "mamba":
keys = set(
[
"mixer.in_proj",
"mixer.x_proj",
"mixer.dt_proj",
"mixer.out_proj",
]
)
else:
raise ValueError(f"Lora does not support {model.model_type}")
for l in model.layers[num_layers - num_lora_layers :]:
# Convert only the last ``num_layers`` blocks. ``max`` clamps a negative
# request to 0, and ``layers[-0:]`` is the whole list, so a non-positive
# value selects every block. (``min`` evaluated to 0 for every positive
# request and therefore always converted all blocks.)
for l in model.layers[-max(num_layers, 0) :]:
    lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
    if lora_layers:
        l.update_modules(tree_unflatten(lora_layers))
@ -152,9 +156,9 @@ def linear_to_lora_layers(
model.update_modules(tree_unflatten(lora_modules))
def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
"""
Apply LoRA layers to the model.
Load any fine-tuned adapters / layers.
Args:
model (nn.Module): The neural network model.
@ -168,11 +172,13 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
with open(adapter_path / "adapter_config.json", "r") as fid:
config = types.SimpleNamespace(**json.load(fid))
fine_tune_type = getattr(config, "fine_tune_type", "lora")
if fine_tune_type != "full":
linear_to_lora_layers(
model,
config.lora_layers,
config.num_layers,
config.lora_parameters,
getattr(config, "use_dora", False),
use_dora=(fine_tune_type == "dora"),
)
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
return model

View File

@ -14,7 +14,6 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
@ -22,8 +21,8 @@ from transformers import PreTrainedTokenizer
from .models.base import KVCache, RotatingKVCache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters
# Constants
MODEL_REMAPPING = {
@ -91,7 +90,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
],
)
)
except RepositoryNotFoundError:
except:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
"Please make sure you specified the local path or Hugging Face"
@ -102,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path
def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
"""
Apply repetition penalty to specific logits based on the given context.
@ -110,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
Args:
logits (mx.array): The logits produced by the language model.
generated_tokens (any): A list of N previous tokens.
tokens (mx.array): A list of N previous tokens.
penalty (float): The repetition penalty factor to be applied.
Returns:
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(generated_tokens) > 0:
indices = generated_tokens
selected_logits = mx.take_along_axis(logits, indices, axis=-1)
if len(tokens) > 0:
selected_logits = mx.take_along_axis(logits, tokens, axis=-1)
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits
logits[mx.arange(tokens.shape[0])[:, None], tokens] = selected_logits
return logits
@ -155,16 +153,17 @@ def generate_step(
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
logit_bias: Optional[Dict[int, float]] = None,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``.
prompts (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
@ -178,10 +177,13 @@ def generate_step(
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
logit_bias (dictionary, optional): Additive logit bias.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@ -195,10 +197,6 @@ def generate_step(
)
def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
if temp == 0:
@ -220,7 +218,29 @@ def generate_step(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
# Work on a copy: the built-in processors appended below must not leak into
# a list the caller passed in (and may reuse for a later call).
logits_processor = list(logits_processor or [])

if repetition_penalty:

    def repetition_penalty_processor(tokens, logits):
        return apply_repetition_penalty(
            logits, tokens[-repetition_context_size:], repetition_penalty
        )

    logits_processor.append(repetition_penalty_processor)

if logit_bias:
    # Build the index/value arrays once, outside the per-step processor.
    indices = mx.array(list(logit_bias.keys()))
    values = mx.array(list(logit_bias.values()))

    def logit_bias_processor(_, logits):
        logits[:, indices] += values
        return logits

    logits_processor.append(logit_bias_processor)
y = prompts
tokens = None
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
@ -235,28 +255,18 @@ def generate_step(
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
repetition_context = prompts
if repetition_context_size:
repetition_context = repetition_context[:, -repetition_context_size:]
def _step(y):
nonlocal repetition_context
logits = model(y, cache=cache)
logits = logits[:, -1, :]
if repetition_penalty:
logits = apply_repetition_penalty(
logits, repetition_context, repetition_penalty
)
y, logprobs = sample(logits)
repetition_context = mx.concatenate([repetition_context, y], axis=-1)
else:
y, logprobs = sample(logits)
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y], axis=-1) if tokens is not None else y
if repetition_context_size:
if repetition_context.shape[1] > repetition_context_size:
repetition_context = repetition_context[:, -repetition_context_size:]
for processor in logits_processor:
logits = processor(tokens, logits)
y, logprobs = sample(logits)
return y, logprobs
while y.shape[1] > prefill_step_size:
@ -265,6 +275,7 @@ def generate_step(
y = y[:, prefill_step_size:]
y, logprobs = _step(y)
mx.async_eval(y)
while True:
next_y, next_logprobs = _step(y)
@ -280,7 +291,7 @@ def stream_generate(
prompt: str,
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[str, None, None]:
"""
A generator producing text based on the given prompt from the model.
@ -320,19 +331,19 @@ def stream_generate(
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, List[str]],
prompt: str,
max_tokens: int = 100,
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> Union[str, List[str]]:
) -> str:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompts (str): The string prompt(s).
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
@ -341,46 +352,31 @@ def generate(
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
is_batch = isinstance(prompt, list)
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
if is_batch:
tokenizer._tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer._tokenizer.pad_token = tokenizer.eos_token
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
prompt_tokens = mx.array(
tokenizer._tokenizer(prompt, padding=True)["input_ids"]
)
output_toks = []
else:
prompt_tokens = mx.array(tokenizer.encode(prompt))[None]
detokenizer = tokenizer.detokenizer
detokenizer.reset()
if verbose:
print("=" * 10)
print("Prompt:", prompt)
tic = time.perf_counter()
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
for (tokens, logprobs), n in zip(
generate_step(prompt_tokens, model, **kwargs),
tic = time.perf_counter()
detokenizer.reset()
for (token, logprobs), n in zip(
generate_step(prompt_tokens[None], model, **kwargs),
range(max_tokens),
):
token = token.item()
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if (tokens == tokenizer.eos_token_id).all():
if token == tokenizer.eos_token_id:
break
if is_batch:
output_toks.append(tokens)
if verbose:
print(".", end="", flush=True)
else:
token = tokens.item()
logprobs = logprobs.squeeze(0)
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
@ -389,35 +385,101 @@ def generate(
else:
print(detokenizer.last_segment, end="", flush=True)
if is_batch:
token_count = n + 1
detokenizer.finalize()
if verbose:
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
def batch_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompts: List[str],
max_tokens: int = 100,
verbose: bool = False,
**kwargs,
) -> str:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompts (List[str]): The string prompts.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
if kwargs.get("max_kv_size", None) is not None:
raise ValueError("max_kv_size is not supported for batch generation")
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
tokenizer._tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer._tokenizer.pad_token = tokenizer.eos_token
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
prompt_tokens = mx.array(
tokenizer._tokenizer(prompts, padding=True)["input_ids"]
)
output_toks = []
tic = time.perf_counter()
for (tokens, _), n in zip(
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if (tokens == tokenizer.eos_token_id).all():
break
output_toks.append(tokens)
if verbose:
print(".", end="", flush=True)
output_toks = mx.concatenate(output_toks, axis=1)
token_count = output_toks.size
response = [
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
for response in tokenizer.batch_decode(output_toks.tolist())
]
else:
token_count = n
detokenizer.finalize()
response = detokenizer.text
if verbose:
gen_time = time.perf_counter() - tic
if token_count <= 0:
print("No tokens generated for this prompt")
if is_batch:
else:
print()
for p, resp in zip(prompt, response):
for p, resp in zip(prompts, response):
print("=" * 10)
print("Prompt:", p)
print(resp)
else:
print(detokenizer.last_segment, flush=True)
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = token_count / gen_time
print("=" * 10)
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
return response
@ -539,7 +601,7 @@ def load(
model = load_model(model_path, lazy, model_config)
if adapter_path is not None:
model = apply_lora_layers(model, adapter_path)
model = load_adapters(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
@ -596,6 +658,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.data.base_model = hf_path
card.text = dedent(
f"""
# {upload_repo}
@ -612,7 +675,16 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
response = generate(model, tokenizer, prompt="hello", verbose=True)
prompt="hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
"""
)
@ -702,6 +774,8 @@ def quantize_model(
quantized_config = copy.deepcopy(config)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
# support hf model tree #957
quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config

View File

@ -10,7 +10,7 @@ with open(package_dir / "requirements.txt") as fid:
requirements = [l.strip() for l in fid.readlines()]
sys.path.append(str(package_dir))
from version import __version__
from _version import __version__
setup(
name="mlx-lm",

View File

@ -11,7 +11,7 @@ import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
from mlx_lm.tuner.dora import DoRAEmbedding
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
def test_llama(self):
    """Check the trainable-parameter count after converting Llama layers to DoRA.

    A small Llama model is frozen, its layers are converted with
    ``linear_to_lora_layers(..., use_dora=True)``, and the number of
    trainable parameters is compared with a closed-form expectation for
    several adapter configurations.
    """
    from mlx_lm.models import llama

    hidden_size = 1024
    intermediate_size = 2048
    dora_layers = 4
    args = llama.ModelArgs(
        model_type="llama",
        hidden_size=hidden_size,
        num_hidden_layers=4,
        intermediate_size=intermediate_size,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=10_000,
    )

    def check_config(params):
        # When no "keys" entry is given, the expected count assumes two
        # adapted projections per layer.
        n_keys = len(params["keys"]) if "keys" in params else 2

        model = llama.Model(args)
        model.freeze()
        tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)

        trainable_params = 0
        for _, value in tree_flatten(model.trainable_parameters()):
            trainable_params += value.size

        # Per layer: two rank x hidden_size terms and one hidden_size term
        # for each adapted key.
        per_layer = params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size
        self.assertEqual(trainable_params, dora_layers * per_layer)

    params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
    check_config(params)

    # Apply each override to the same dict in turn and re-check.
    for option, value in (("rank", 1), ("keys", ["self_attn.k_proj"])):
        params[option] = value
        check_config(params)
def test_dora_m_parameter(self):
    """Check that ``DoRALinear.m`` equals the axis-1 norm of the wrapped weight.

    Covers three cases: the layer as constructed, after swapping in a new
    ``nn.Linear`` via ``set_linear``, and after swapping in an
    ``nn.QuantizedLinear`` (compared against the dequantized weight).
    """
    dora_lin = DoRALinear(input_dims=100, output_dims=100)
    self.assertTrue(
        mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
    )

    # Recomputes m when changing Linear
    lin = nn.Linear(10, 10)
    dora_lin.set_linear(lin)
    self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))

    # Works with quantized weights
    quantized_linear = nn.QuantizedLinear(512, 512)
    dora_lin.set_linear(quantized_linear)
    dequantized_weight = mx.dequantize(
        quantized_linear.weight,
        quantized_linear.scales,
        quantized_linear.biases,
        quantized_linear.group_size,
        quantized_linear.bits,
    )
    self.assertTrue(
        mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
    )
def test_dora_from_linear(self):
    """Build DoRA layers from dense and quantized linears and check ``m`` and shapes.

    For each base layer, ``m`` must match the axis-1 norm of the
    (dequantized) base weight, and the adapter tensors must have the
    shapes implied by ``in_dims``, ``out_dims`` and the rank ``r``.
    """
    in_dims = 256
    out_dims = 256
    r = 4

    def check(dora_layer, base_weight):
        # Same four assertions for the dense and the quantized variant.
        self.assertTrue(
            mx.allclose(dora_layer.m, mx.linalg.norm(base_weight, axis=1))
        )
        self.assertEqual(dora_layer.lora_a.shape, (in_dims, r))
        self.assertEqual(dora_layer.lora_b.shape, (r, out_dims))
        self.assertEqual(dora_layer.m.shape, (out_dims,))

    # Dense base layer
    linear = nn.Linear(in_dims, out_dims)
    dora_lin = DoRALinear.from_base(linear, r)
    check(dora_lin, linear.weight)

    # Quantized base layer: compare against its dequantized weight
    quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
    dequantized_weight = mx.dequantize(
        quantized_linear.weight,
        quantized_linear.scales,
        quantized_linear.biases,
        quantized_linear.group_size,
        quantized_linear.bits,
    )
    dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
    check(dora_quant_lin, dequantized_weight)
def test_dora_to_linear(self):
    """Fuse freshly built DoRA layers and compare the result with the base layers.

    The test expects ``fuse()`` on a just-created adapter to reproduce the
    base layer's weight and bias, both for a dense ``nn.Linear`` and for an
    ``nn.QuantizedLinear`` fused with ``de_quantize=True``.
    """
    in_dims = 256
    out_dims = 256
    r = 4

    # Dense base layer
    linear = nn.Linear(in_dims, out_dims, bias=True)
    to_linear = DoRALinear.from_base(linear, r).fuse()
    self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
    self.assertTrue(mx.allclose(linear.bias, to_linear.bias))

    # Quantized base layer
    quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
    dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)

    # Dequantize
    to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
    self.assertTrue(
        mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
    )

    # Reference weight: the quantized base layer's weight, dequantized
    expected_weight = mx.dequantize(
        quantized_linear.weight,
        quantized_linear.scales,
        quantized_linear.biases,
        quantized_linear.group_size,
        quantized_linear.bits,
    )
    self.assertTrue(mx.allclose(expected_weight, to_linear_from_quantized.weight))
def test_dora_dtype(self):
    """Check that a DoRA layer built from a float16 linear returns float16 output."""
    in_dims = 256
    out_dims = 256
    r = 4

    base = nn.Linear(in_dims, out_dims, bias=True)
    base.set_dtype(mx.float16)
    dora_lin = DoRALinear.from_base(base, r)

    # float16 input batch of two rows
    x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
    output = dora_lin(x)
    self.assertEqual(output.dtype, mx.float16)
class TestScheduleConfig(unittest.TestCase):
def test_join(self):

View File

@ -0,0 +1,55 @@
# Copyright © 2024 Apple Inc.
import unittest
from mlx_lm.utils import generate, load
class TestGenerate(unittest.TestCase):
    """Tests for ``generate`` run against one model loaded once per class."""

    @classmethod
    def setUpClass(cls):
        """Load the model and tokenizer shared by all tests in this class."""
        hf_model_path = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        loaded = load(hf_model_path)
        cls.model, cls.tokenizer = loaded

    def test_generate(self):
        """Generation runs to completion; the output itself is not inspected."""
        # Simple test that generation runs
        _ = generate(
            self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
        )

    def test_generate_with_logit_bias(self):
        """A logit bias on token ids 0 and 1 yields the expected five-character text."""
        # NOTE(review): presumably token id 0 decodes to "!" for this
        # tokenizer, which is why five "!" are expected -- confirm.
        bias = {0: 2000.0, 1: -20.0}
        output = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logit_bias=bias,
        )
        self.assertEqual(output, "!!!!!")

    def test_generate_with_processor(self):
        """A logits processor receives the token history on each call."""
        prompt_tokens = self.tokenizer.encode("hello")
        # Holds the tokens passed to the most recent processor call.
        captured = {"toks": None}

        def logits_processor(toks, logits):
            captured["toks"] = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processor=[logits_processor],
        )

        # The test expects the last call to see the prompt plus five tokens.
        expected_length = len(prompt_tokens) + 5
        self.assertEqual(len(captured["toks"]), expected_length)
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

View File

@ -5,6 +5,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache, RotatingKVCache
from mlx_lm.utils import make_kv_caches
class TestModels(unittest.TestCase):
@ -100,13 +101,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
cache = make_kv_caches(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
@ -397,6 +392,26 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_mamba(self):
    """Run the shared model test runner on a Mamba model built from explicit args."""
    from mlx_lm.models import mamba

    config = {
        "model_type": "mamba",
        "vocab_size": 10000,
        "use_bias": False,
        "use_conv_bias": True,
        "conv_kernel": 4,
        "hidden_size": 768,
        "num_hidden_layers": 24,
        "state_size": 16,
        "intermediate_size": 1536,
        "time_step_rank": 48,
    }
    args = mamba.ModelArgs(**config)
    model = mamba.Model(args)

    self.model_test_runner(
        model, args.model_type, args.vocab_size, args.num_hidden_layers
    )
def test_gpt2(self):
from mlx_lm.models import gpt2

View File

@ -1,5 +1,7 @@
# Copyright © 2024 Apple Inc.
import http
import json
import threading
import unittest
@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
def test_handle_models(self):
    """Query ``/v1/models`` and check the structure of the returned listing."""
    response = requests.get(f"http://localhost:{self.port}/v1/models")
    self.assertEqual(response.status_code, 200)

    # The body must be a non-empty list-type object.
    payload = json.loads(response.text)
    self.assertEqual(payload["object"], "list")
    entries = payload["data"]
    self.assertIsInstance(entries, list)
    self.assertGreater(len(entries), 0)

    # Only the first entry is inspected.
    first = entries[0]
    self.assertIn("id", first)
    self.assertEqual(first["object"], "model")
    self.assertIn("created", first)
def test_sequence_overlap(self):
from mlx_lm.server import sequence_overlap

View File

@ -35,6 +35,8 @@ _MODELS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
@ -52,6 +54,8 @@ _ALIGNMENT_HEADS = {
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}

View File

@ -1,5 +1,5 @@
# Copyright © 2023-2024 Apple Inc.
from . import audio, decoding, load_models
from ._version import __version__
from .transcribe import transcribe
from .version import __version__

View File

@ -12,7 +12,7 @@ with open(package_dir / "requirements.txt") as fid:
sys.path.append(str(package_dir))
from version import __version__
from _version import __version__
setup(
name="mlx-whisper",