mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:36:37 +08:00
Merge branch 'main' into feat/batch_generate
This commit is contained in:
commit
8fb82fee43
@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
|||||||
- Markus Enzweiler: Added the `cvae` examples.
|
- Markus Enzweiler: Added the `cvae` examples.
|
||||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||||
- Shiyu Li: Added the `Segment Anything Model`.
|
- Shiyu Li: Added the `Segment Anything Model`.
|
||||||
|
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
|
@ -27,6 +27,7 @@ Some more useful examples are listed below.
|
|||||||
### Audio Models
|
### Audio Models
|
||||||
|
|
||||||
- Speech recognition with [OpenAI's Whisper](whisper).
|
- Speech recognition with [OpenAI's Whisper](whisper).
|
||||||
|
- Audio compression and generation with [Meta's EnCodec](encodec).
|
||||||
|
|
||||||
### Multimodal models
|
### Multimodal models
|
||||||
|
|
||||||
|
83
encodec/README.md
Normal file
83
encodec/README.md
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
# EnCodec
|
||||||
|
|
||||||
|
An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
|
||||||
|
generate audio.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
Install the requirements:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Optionally install FFmpeg and SciPy for loading and saving audio files,
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
Install [FFmpeg](https://ffmpeg.org/):
|
||||||
|
|
||||||
|
```
|
||||||
|
# on macOS using Homebrew (https://brew.sh/)
|
||||||
|
brew install ffmpeg
|
||||||
|
```
|
||||||
|
|
||||||
|
Install SciPy:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install scipy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example
|
||||||
|
|
||||||
|
An example using the model:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import mlx.core as mx
|
||||||
|
from utils import load, load_audio, save_audio
|
||||||
|
|
||||||
|
# Load the 48 kHz model and preprocessor.
|
||||||
|
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||||
|
|
||||||
|
# Load an audio file
|
||||||
|
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
||||||
|
|
||||||
|
# Preprocess the audio (this can also be a list of arrays for batched
|
||||||
|
# processing).
|
||||||
|
feats, mask = processor(audio)
|
||||||
|
|
||||||
|
# Encode at the given bandwidth. A lower bandwidth results in more
|
||||||
|
# compression but lower reconstruction quality.
|
||||||
|
@mx.compile
|
||||||
|
def encode(feats, mask):
|
||||||
|
return model.encode(feats, mask, bandwidth=3)
|
||||||
|
|
||||||
|
# Decode to reconstruct the audio
|
||||||
|
@mx.compile
|
||||||
|
def decode(codes, scales, mask):
|
||||||
|
return model.decode(codes, scales, mask)
|
||||||
|
|
||||||
|
|
||||||
|
codes, scales = encode(feats, mask)
|
||||||
|
reconstructed = decode(codes, scales, mask)
|
||||||
|
|
||||||
|
# Trim any padding:
|
||||||
|
reconstructed = reconstructed[0, : len(audio)]
|
||||||
|
|
||||||
|
# Save the audio as a wave file
|
||||||
|
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
|
||||||
|
```
|
||||||
|
|
||||||
|
The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available in the
|
||||||
|
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
|
||||||
|
in several data types.
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
|
||||||
|
To convert models, use the `convert.py` script. To see the options, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python convert.py -h
|
||||||
|
```
|
||||||
|
|
||||||
|
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
|
||||||
|
[code](https://github.com/facebookresearch/encodec) for more details.
|
30
encodec/benchmarks/bench_mx.py
Normal file
30
encodec/benchmarks/bench_mx.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.

# Benchmark: time one compiled encode + decode round trip of the MLX EnCodec
# model on random audio and print the mean wall-clock time per iteration in
# milliseconds.

import time

import mlx.core as mx
from utils import load

# Untimed iterations run first so one-off costs are kept out of the timed loop.
NUM_WARMUP = 5
# Timed iterations that the reported mean is taken over.
NUM_ITERS = 10

model, processor = load("mlx-community/encodec-48khz-float32")

# Random input; presumably (samples, channels), i.e. 6 s of stereo at 48 kHz.
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
# Force evaluation now so that loading/preprocessing is not part of the timing.
mx.eval(model, feats, mask)


@mx.compile
def fun():
    """Encode the preprocessed features at bandwidth 3 and decode them back."""
    codes, scales = model.encode(feats, mask, bandwidth=3)
    reconstructed = model.decode(codes, scales, mask)
    return reconstructed


for _ in range(NUM_WARMUP):
    mx.eval(fun())

# perf_counter is monotonic and high resolution, unlike time.time(), which can
# jump when the system clock is adjusted.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    mx.eval(fun())
toc = time.perf_counter()
ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")
|
34
encodec/benchmarks/bench_pt.py
Normal file
34
encodec/benchmarks/bench_pt.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.

# Benchmark: time one encode + decode round trip of the Hugging Face PyTorch
# EnCodec model on the "mps" device and print the mean wall-clock time per
# iteration in milliseconds.

import time

import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel

# Untimed iterations run first so one-off costs are kept out of the timed loop.
NUM_WARMUP = 5
# Timed iterations that the reported mean is taken over.
NUM_ITERS = 10

processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
# Random input; presumably (channels, samples), i.e. 6 s of stereo at 48 kHz.
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    """Encode the preprocessed input, decode it back and wait for completion."""
    pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
    pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
    )
    # Wait for the MPS device to finish so the work is included in the timing.
    torch.mps.synchronize()


for _ in range(NUM_WARMUP):
    fun()

# perf_counter is monotonic and high resolution, unlike time.time(), which can
# jump when the system clock is adjusted.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    fun()
toc = time.perf_counter()
ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")
|
213
encodec/convert.py
Normal file
213
encodec/convert.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from textwrap import dedent
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
import encodec
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_from_hub(hf_repo: str) -> Path:
    """Download the ``*.json`` and ``*.safetensors`` files of a Hub repo.

    Args:
        hf_repo (str): Hugging Face repo id to download from.

    Returns:
        Path: The local directory returned by ``snapshot_download``.
    """
    local_dir = snapshot_download(
        repo_id=hf_repo,
        allow_patterns=["*.json", "*.safetensors"],
    )
    return Path(local_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Write a model card into ``path`` and upload the folder to the Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    # Front matter plus a short description linking the converted and the
    # original repos; ``dedent`` strips the common source indentation.
    card_text = dedent(
        f"""
        ---
        language: en
        license: other
        library: mlx
        tags:
        - mlx
        ---

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path}).

        This model is intended to be used with the [EnCodec MLX
        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
        """
    )

    readme_path = os.path.join(path, "README.md")
    ModelCard(card_text).save(readme_path)

    logging.set_verbosity_info()

    # Create the target repo if needed, then push the whole folder.
    hub = HfApi()
    hub.create_repo(repo_id=upload_repo, exist_ok=True)
    hub.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
|
||||||
|
|
||||||
|
|
||||||
|
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Save ``weights`` as one safetensors file plus a JSON index.

    Writes ``model.safetensors`` and ``model.safetensors.index.json`` into
    ``save_path``, creating the directory if it does not exist.

    Args:
        save_path (Union[str, Path]): Output directory.
        weights (Dict[str, Any]): Mapping from parameter name to array; each
            value must expose ``nbytes``.
    """
    out_dir = Path(save_path) if isinstance(save_path, str) else save_path
    out_dir.mkdir(parents=True, exist_ok=True)

    shard_name = "model.safetensors"
    total_size = sum(arr.nbytes for arr in weights.values())

    mx.save_safetensors(
        str(out_dir / shard_name), weights, metadata={"format": "mlx"}
    )

    # Every weight lives in the single shard; keys are sorted for readability.
    index_data = {
        "metadata": {"total_size": total_size},
        "weight_map": {name: shard_name for name in sorted(weights)},
    }

    with open(out_dir / "model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The ``_name_or_path`` key is dropped and the remaining keys are sorted
    before writing for better readability. The ``config`` dict passed in is
    not modified.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Filter into a new dict instead of popping from the argument: the caller
    # may pass a live mapping (e.g. ``vars(namespace)``) that it still uses.
    cleaned = {key: value for key, value in config.items() if key != "_name_or_path"}

    # Sort the config for better readability.
    cleaned = dict(sorted(cleaned.items()))

    with open(config_path, "w") as fid:
        json.dump(cleaned, fid, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def convert(
    upload: bool,
    model: str,
    dtype: Union[str, None] = None,
):
    """Convert a ``facebook/encodec_{model}`` checkpoint to MLX format.

    Downloads the checkpoint, remaps its weights to the parameter names and
    layouts of ``encodec.EncodecModel``, optionally casts them, and writes the
    weights and config to ``mlx_models``.

    Args:
        upload (bool): If true, upload the converted folder to
            ``mlx-community/encodec-{model}-{dtype}``.
        model (str): Model variant used to build the repo names,
            e.g. ``"48khz"``.
        dtype (str, optional): Name of an ``mlx.core`` dtype to cast the
            weights to. If ``None`` the weights are saved without casting.
    """
    hf_repo = f"facebook/encodec_{model}"
    # NOTE(review): with no cast requested the upload repo used to be named
    # "...-None". Falling back to "float32" assumes the source checkpoint is
    # stored in float32 -- confirm.
    dtype_tag = dtype if dtype is not None else "float32"
    mlx_repo = f"mlx-community/encodec-{model}-{dtype_tag}"
    path = fetch_from_hub(hf_repo)
    save_path = Path("mlx_models")

    weights = mx.load(str(path / "model.safetensors"))

    with open(path / "config.json", "r") as fid:
        config = SimpleNamespace(**json.load(fid))

    # Kept under its own name so the ``model`` argument (a string) is not
    # rebound to the model object.
    encodec_model = encodec.EncodecModel(config)

    new_weights = {}
    for k, v in weights.items():
        basename, pname = k.rsplit(".", 1)
        if pname == "weight_v":
            # Fold the (weight_g, weight_v) pair into one weight:
            # g * v / ||v||, with the norm taken over axes 1 and 2.
            g = weights[basename + ".weight_g"]
            v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
            k = basename + ".weight"
        elif pname in ("weight_g", "embed_avg", "cluster_size", "inited"):
            # weight_g is consumed above; the other entries have no
            # counterpart in the MLX model.
            continue
        elif "lstm" in basename:
            # Names look like "weight_ih_l0" / "bias_hh_l0".
            w_or_b, ih_or_hh, layer_tag = pname.split("_")
            if w_or_b == "weight":
                new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
            elif w_or_b == "bias" and ih_or_hh == "ih":
                # Merged into the matching "hh" bias below.
                continue
            else:
                v = v + weights[k.replace("_hh_", "_ih_")]
                new_pname = "bias"
            # Drop the leading "l" of the layer tag to get the list index.
            k = basename + "." + layer_tag[1:] + "." + new_pname
        if "conv.weight" in k:
            # Possibly a transposed conv which has a different order
            if "decoder" in k:
                layer_idx = int(k.split(".")[2])
                layer = encodec_model.decoder.layers[layer_idx]
                if "conv" in layer and isinstance(layer.conv, nn.ConvTranspose1d):
                    v = mx.moveaxis(v, 0, 2)
                else:
                    v = mx.moveaxis(v, 1, 2)
            else:
                v = mx.moveaxis(v, 1, 2)

        new_weights[k] = v
    weights = new_weights

    # Loading into the model validates the remapped names and shapes.
    encodec_model.load_weights(list(weights.items()))

    if dtype is not None:
        t = getattr(mx, dtype)
        weights = {k: v.astype(t) for k, v in weights.items()}

    save_weights(save_path, weights)

    save_config(vars(config), config_path=save_path / "config.json")

    if upload:
        upload_to_hub(save_path, mlx_repo, hf_repo)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion.
    arg_parser = argparse.ArgumentParser(
        description="Convert EnCodec weights to MLX."
    )
    arg_parser.add_argument(
        "--model",
        choices=["24khz", "32khz", "48khz"],
        default="48khz",
        type=str,
        help="",
    )
    arg_parser.add_argument(
        "--upload",
        help="Upload the weights to Hugging Face.",
        action="store_true",
    )
    arg_parser.add_argument(
        "--dtype",
        choices=["float32", "bfloat16", "float16"],
        default="float32",
        type=str,
        help="Data type to convert the model to.",
    )
    cli_args = arg_parser.parse_args()
    convert(upload=cli_args.upload, model=cli_args.model, dtype=cli_args.dtype)
|
671
encodec/encodec.py
Normal file
671
encodec/encodec.py
Normal file
@ -0,0 +1,671 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Custom Metal kernel that applies the gate non-linearities and state update
# of a single LSTM time step. It reads four pre-activations (i, f, g, o)
# spaced ``hidden_size`` apart from ``h_in`` and ``x``, sums them pairwise,
# and writes:
#   cell_state   = f * cell + i * g
#   hidden_state = o * tanh(cell_state)
# NOTE(review): ``elem`` already contains ``b * d`` and is used both to index
# ``cell`` / the outputs and as the base of ``x_index`` (which adds
# ``b * num_time_steps * d`` on top) -- this looks consistent only for batch
# size 1; confirm before relying on batched input.
_lstm_kernel = mx.fast.metal_kernel(
    name="lstm",
    input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
    output_names=["hidden_state", "cell_state"],
    # sigmoid evaluates exp() on -|x| and returns 1 - y for negative inputs.
    header="""
    template <typename T>
    T sigmoid(T x) {
        auto y = 1 / (1 + metal::exp(-metal::abs(x)));
        return (x < 0) ? 1 - y : y;
    }
    """,
    source="""
        uint b = thread_position_in_grid.x;
        uint d = hidden_size * 4;

        uint elem = b * d + thread_position_in_grid.y;
        uint index = elem;
        uint x_index = b * num_time_steps * d + time_step * d + index;

        auto i = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto f = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto g = metal::precise::tanh(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto o = sigmoid(h_in[index] + x[x_index]);

        cell_state[elem] = f * cell[elem] + i * g;
        hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
    """,
)
|
||||||
|
|
||||||
|
|
||||||
|
def lstm_custom(x, h_in, cell, time_step):
    """Run the custom LSTM kernel for one time step.

    Args:
        x: 3-D array of input pre-activations; the time axis is ``-2``.
        h_in: Hidden-state pre-activations; its size is four times the number
            of grid rows launched.
        cell: Previous cell state; its shape is used for both outputs and its
            last dimension as the hidden size.
        time_step: Index of the time step to process.

    Returns:
        The kernel outputs ``(hidden_state, cell_state)``, each with the shape
        of ``cell`` and the dtype of ``h_in``.
    """
    assert x.ndim == 3, "Input to LSTM must have 3 dimensions."

    state_shape = cell.shape
    hidden_size = state_shape[-1]
    batch_size = x.shape[0]
    num_time_steps = x.shape[-2]
    state_dtype = h_in.dtype

    return _lstm_kernel(
        inputs=[x, h_in, cell, hidden_size, time_step, num_time_steps],
        output_shapes=[state_shape, state_shape],
        output_dtypes=[state_dtype, state_dtype],
        grid=(batch_size, h_in.size // 4, 1),
        threadgroup=(256, 1, 1),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class LSTM(nn.Module):
    """Single-layer LSTM whose per-step update runs in a custom Metal kernel.

    Args:
        input_size (int): Size of the last dimension of the input.
        hidden_size (int): Size of the hidden and cell states.
        bias (bool): Whether to add a bias to the input projection.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        # The four gate projections are stacked along the first axis.
        self.Wx = mx.zeros((4 * hidden_size, input_size))
        self.Wh = mx.zeros((4 * hidden_size, hidden_size))
        self.bias = mx.zeros((4 * hidden_size,)) if bias else None

    def __call__(self, x, hidden=None, cell=None):
        """Run the LSTM over the time axis (axis ``-2``) of ``x``.

        Args:
            x: Input array; must be 3-D by the time it reaches the kernel.
            hidden: Optional initial hidden state. If ``None`` the first step
                uses zero recurrent pre-activations.
            cell: Optional initial cell state. If ``None`` zeros are used.

        Returns:
            The hidden states of all time steps stacked along axis ``-2``.
        """
        # Project the whole sequence through the input weights up front.
        if self.bias is not None:
            x = mx.addmm(self.bias, x, self.Wx.T)
        else:
            x = x @ self.Wx.T

        all_hidden = []

        B = x.shape[0]
        # Check against None explicitly: ``cell or ...`` would take the truth
        # value of an array, which fails for multi-element arrays and would
        # discard a legitimate all-zero single-element state.
        if cell is None:
            cell = mx.zeros((B, self.hidden_size), x.dtype)
        for t in range(x.shape[-2]):
            if hidden is None:
                hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
            else:
                hidden = hidden @ self.Wh.T
            hidden, cell = lstm_custom(x, hidden, cell, t)
            all_hidden.append(hidden)

        return mx.stack(all_hidden, axis=-2)
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and normalization.

    The input is padded along axis 1 (the axis this class treats as length)
    before the convolution so that the last window is complete.
    """

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride, dilation=dilation
        )
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)

        self.stride = stride

        # Effective kernel size with dilations.
        self.kernel_size = (kernel_size - 1) * dilation + 1

        # The padding must cover the dilated receptive field; using the raw
        # kernel size under-pads whenever dilation > 1. For dilation == 1 the
        # two values are identical.
        self.padding_total = self.kernel_size - stride

    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: mx.array,
    ) -> int:
        """Return the right padding needed so the last window is complete."""
        length = hidden_states.shape[1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = int(math.ceil(n_frames)) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
        return ideal_length - length

    def _pad1d(
        self,
        hidden_states: mx.array,
        paddings: Tuple[int, int],
        mode: str = "zero",
        value: float = 0.0,
    ):
        """Pad ``hidden_states`` along axis 1 by ``paddings = (left, right)``.

        ``mode="reflect"`` mirrors the signal around its edges; any other mode
        pads with the constant ``value``.
        """
        if mode != "reflect":
            # A bare (left, right) pair would make mx.pad extend every axis,
            # so spell out zero padding for all axes but the length axis.
            pad_width = [(0, 0)] * hidden_states.ndim
            pad_width[1] = (paddings[0], paddings[1])
            return mx.pad(
                hidden_states, pad_width, mode="constant", constant_values=value
            )

        length = hidden_states.shape[1]
        # Reflection excludes the edge sample itself, hence the offsets of 1.
        prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
        suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
        return mx.concatenate([prefix, hidden_states, suffix], axis=1)

    def __call__(self, hidden_states):
        """Pad, convolve and (optionally) normalize ``hidden_states``."""
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(
                hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
            )
        else:
            # Asymmetric padding required for odd strides
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states,
                (padding_left, padding_right + extra_padding),
                mode=self.pad_mode,
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecConvTranspose1d(nn.Module):
    """ConvTranspose1d with asymmetric or causal padding and normalization.

    After the transposed convolution, ``kernel_size - stride`` samples are
    trimmed from axis 1, split between the left and right edges.
    """

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        if config.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
        self.padding_total = kernel_size - stride

    def __call__(self, hidden_states):
        """Upsample ``hidden_states`` and trim the surplus edge samples."""
        out = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            out = self.norm(out)

        # Causal models trim a configurable share from the right; otherwise
        # the trim is split (almost) evenly.
        trim_right = (
            math.ceil(self.padding_total * self.trim_right_ratio)
            if self.causal
            else self.padding_total // 2
        )
        trim_left = self.padding_total - trim_right

        stop = out.shape[1] - trim_right
        return out[:, trim_left:stop, :]
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecLSTM(nn.Module):
    """Stack of ``config.num_lstm_layers`` LSTMs with a skip connection."""

    def __init__(self, config, dimension):
        super().__init__()
        num_layers = config.num_lstm_layers
        self.lstm = [LSTM(dimension, dimension) for _ in range(num_layers)]

    def __call__(self, hidden_states):
        """Apply the LSTM layers in order and add the input back."""
        out = hidden_states
        for layer in self.lstm:
            out = layer(out)
        return out + hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecResnetBlock(nn.Module):
    """
    Residual block from SEANet model as used by EnCodec.

    Two ELU + convolution stages (kernel sizes
    ``(config.residual_kernel_size, 1)``) whose output is added to a shortcut
    of the input.
    """

    def __init__(self, config, dim: int, dilations: List[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        if len(dilations) != len(kernel_sizes):
            raise ValueError("Number of kernel sizes should match number of dilations")

        # Channel count between the two convolutions.
        hidden = dim // config.compress
        last = len(kernel_sizes) - 1

        stages = []
        for idx, (ksize, dil) in enumerate(zip(kernel_sizes, dilations)):
            stage_in = dim if idx == 0 else hidden
            stage_out = dim if idx == last else hidden
            stages.append(nn.ELU())
            stages.append(
                EncodecConv1d(config, stage_in, stage_out, ksize, dilation=dil)
            )
        self.block = stages

        # A 1x1 convolution shortcut is used unless the config disables it.
        use_conv_shortcut = getattr(config, "use_conv_shortcut", True)
        self.shortcut = (
            EncodecConv1d(config, dim, dim, kernel_size=1)
            if use_conv_shortcut
            else nn.Identity()
        )

    def __call__(self, hidden_states):
        """Return ``shortcut(x) + block(x)``."""
        out = hidden_states
        for stage in self.block:
            out = stage(out)
        return self.shortcut(hidden_states) + out
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecEncoder(nn.Module):
    """SEANet encoder as used by EnCodec.

    Layer order: input convolution; per ratio (in reversed
    ``config.upsampling_ratios`` order) the residual blocks, an ELU and a
    strided convolution that doubles the channels; then the LSTM stack, an ELU
    and the output convolution.
    """

    def __init__(self, config):
        super().__init__()
        layers = []
        layers.append(
            EncodecConv1d(
                config, config.audio_channels, config.num_filters, config.kernel_size
            )
        )

        width_mult = 1
        for ratio in reversed(config.upsampling_ratios):
            channels = width_mult * config.num_filters
            layers.extend(
                EncodecResnetBlock(
                    config, channels, [config.dilation_growth_rate**j, 1]
                )
                for j in range(config.num_residual_layers)
            )
            layers.append(nn.ELU())
            # Downsample by `ratio` while doubling the channel count.
            layers.append(
                EncodecConv1d(
                    config,
                    channels,
                    channels * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            )
            width_mult *= 2

        bottleneck = width_mult * config.num_filters
        layers.append(EncodecLSTM(config, bottleneck))
        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(
                config,
                bottleneck,
                config.hidden_size,
                config.last_kernel_size,
            )
        )

        self.layers = layers

    def __call__(self, hidden_states):
        """Apply every layer in sequence."""
        out = hidden_states
        for layer in self.layers:
            out = layer(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecDecoder(nn.Module):
    """SEANet decoder as used by EnCodec.

    Layer order: input convolution; the LSTM stack; per ratio in
    ``config.upsampling_ratios`` an ELU, a transposed convolution that halves
    the channels and the residual blocks; then an ELU and the output
    convolution.
    """

    def __init__(self, config):
        super().__init__()
        width_mult = int(2 ** len(config.upsampling_ratios))
        top_channels = width_mult * config.num_filters

        layers = [
            EncodecConv1d(
                config,
                config.hidden_size,
                top_channels,
                config.kernel_size,
            )
        ]
        layers.append(EncodecLSTM(config, top_channels))

        for ratio in config.upsampling_ratios:
            in_channels = width_mult * config.num_filters
            out_channels = in_channels // 2
            layers.append(nn.ELU())
            # Upsample by `ratio` while halving the channel count.
            layers.append(
                EncodecConvTranspose1d(
                    config,
                    in_channels,
                    out_channels,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            )
            layers.extend(
                EncodecResnetBlock(
                    config, out_channels, (config.dilation_growth_rate**j, 1)
                )
                for j in range(config.num_residual_layers)
            )
            width_mult //= 2

        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(
                config,
                config.num_filters,
                config.audio_channels,
                config.last_kernel_size,
            )
        )
        self.layers = layers

    def __call__(self, hidden_states):
        """Apply every layer in sequence."""
        out = hidden_states
        for layer in self.layers:
            out = layer(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecEuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance."""

    def __init__(self, config):
        super().__init__()
        # One row per code: (codebook_size, codebook_dim).
        self.embed = mx.zeros((config.codebook_size, config.codebook_dim))

    def quantize(self, hidden_states):
        """Return, for each row of ``hidden_states``, the index of the nearest code.

        Uses ``-(|x|^2 - 2 x.e + |e|^2)`` and takes the arg-max over codes.
        """
        codes_t = self.embed.T
        state_sq_norms = hidden_states.square().sum(axis=1, keepdims=True)
        code_sq_norms = codes_t.square().sum(axis=0, keepdims=True)
        neg_sq_dist = -(state_sq_norms - 2 * hidden_states @ codes_t + code_sq_norms)
        return neg_sq_dist.argmax(axis=-1)

    def encode(self, hidden_states):
        """Quantize over the last axis, keeping the leading shape."""
        leading_shape = hidden_states.shape[:-1]
        flat_states = hidden_states.reshape((-1, hidden_states.shape[-1]))
        flat_indices = self.quantize(flat_states)
        return flat_indices.reshape(*leading_shape)

    def decode(self, embed_ind):
        """Look up the code vectors for ``embed_ind``."""
        return self.embed[embed_ind]
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.

    Thin wrapper that forwards to an ``EncodecEuclideanCodebook``.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        """Return the codebook indices for ``hidden_states``."""
        indices = self.codebook.encode(hidden_states)
        return indices

    def decode(self, embed_ind):
        """Return the code vectors for the indices ``embed_ind``."""
        vectors = self.codebook.decode(embed_ind)
        return vectors
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Holds a list of vector quantizers; each one encodes the residual left by
    the ones before it.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook_size = config.codebook_size

        hop_length = np.prod(config.upsampling_ratios)
        self.frame_rate = math.ceil(config.sampling_rate / hop_length)

        # The quantizer count is derived from the last target bandwidth.
        last_bandwidth = config.target_bandwidths[-1]
        self.num_quantizers = int(1000 * last_bandwidth // (self.frame_rate * 10))

        self.layers = [
            EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
        ]

    def get_num_quantizers_for_bandwidth(
        self, bandwidth: Optional[float] = None
    ) -> int:
        """Return num_quantizers based on specified target bandwidth."""
        bw_per_q = math.log2(self.codebook_size) * self.frame_rate
        if bandwidth is None or not bandwidth > 0.0:
            # No usable bandwidth given: use every quantizer.
            return self.num_quantizers
        return int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))

    def encode(
        self, embeddings: mx.array, bandwidth: Optional[float] = None
    ) -> mx.array:
        """
        Encode ``embeddings`` at the given bandwidth.

        The bandwidth selects how many quantizers are used; the indices
        produced by each one are stacked along axis 1.
        """
        active = self.get_num_quantizers_for_bandwidth(bandwidth)
        remaining = embeddings
        collected = []
        for quantizer in self.layers[:active]:
            idx = quantizer.encode(remaining)
            # Pass on what this quantizer could not represent.
            remaining = remaining - quantizer.decode(idx)
            collected.append(idx)
        return mx.stack(collected, axis=1)

    def decode(self, codes: mx.array) -> mx.array:
        """Decode the given codes to the quantized representation."""
        summed = None
        per_quantizer = codes.split(codes.shape[1], axis=1)
        for i, indices in enumerate(per_quantizer):
            contribution = self.layers[i].decode(indices.squeeze(1))
            summed = contribution if summed is None else contribution + summed
        return summed
|
||||||
|
|
||||||
|
|
||||||
|
class EncodecModel(nn.Module):
    """EnCodec audio codec: an encoder, a residual vector quantizer and a decoder.

    ``encode`` turns a waveform into discrete codes (plus optional per-chunk
    scales) and ``decode`` reconstructs a waveform from them. When
    ``config.chunk_length_s`` is set the audio is processed in overlapping
    chunks that are blended back together with ``_linear_overlap_add``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = EncodecEncoder(config)
        self.decoder = EncodecDecoder(config)
        self.quantizer = EncodecResidualVectorQuantizer(config)

    def _encode_frame(
        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encode a single chunk of audio into discrete codes.

        Args:
            input_values (mx.array): One chunk of audio; axis 1 is time and
              axis 2 is the channel axis.
            bandwidth (float): Target bandwidth passed to the quantizer.
            padding_mask (mx.array): Mask over the first two axes of
              ``input_values``; masked-out samples are zeroed before the
              scale is computed.

        Returns:
            A tuple ``(codes, scale)``. ``scale`` is ``None`` unless
            ``config.normalize`` is set.

        Raises:
            RuntimeError: If the chunk is longer than ``config.chunk_length_s``.
        """
        length = input_values.shape[1]
        duration = length / self.config.sampling_rate

        # The small epsilon tolerates floating point error in the division.
        if (
            self.config.chunk_length_s is not None
            and duration > 1e-5 + self.config.chunk_length_s
        ):
            raise RuntimeError(
                f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
            )

        scale = None
        if self.config.normalize:
            # Zero out the padded samples so they do not affect the scale.
            input_values = input_values * padding_mask[..., None]
            # Average the channels, then use the RMS over time as the scale.
            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
            # The 1e-8 keeps the division below finite for silent chunks.
            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        return codes, scale

    def encode(
        self,
        input_values: mx.array,
        padding_mask: Optional[mx.array] = None,
        bandwidth: Optional[float] = None,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (mx.array): The input audio waveform with shape
              ``(batch_size, sequence_length, channels)``.
            padding_mask (mx.array, optional): Mask with shape
              ``(batch_size, sequence_length)`` marking the valid samples.
              If ``None``, every sample is treated as valid.
            bandwidth (float, optional): The target bandwidth. Must be one of
              ``config.target_bandwidths``. If ``None``, uses the first entry
              of ``config.target_bandwidths``. bandwidth is represented as a
              thousandth of what it is, e.g. 6kbps bandwidth is represented
              as bandwidth == 6.0

        Returns:
            A tuple ``(codes, scales)``. ``codes`` stacks the per-chunk codes
            along a new leading axis, so its first axis indexes chunks.
            ``scales`` is a list with one entry per chunk, each ``None``
            unless ``config.normalize`` is set.

        Raises:
            ValueError: If the bandwidth is unsupported, the number of
              channels is not 1 or 2, or the input length does not match the
              chunking scheme.
        """

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.config.target_bandwidths}."
            )

        _, input_length, channels = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(
                f"Number of audio channels must be 1 or 2, but got {channels}"
            )

        chunk_length = self.chunk_length
        if chunk_length is None:
            # No chunking: the whole input is a single frame.
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.chunk_stride

        if padding_mask is None:
            padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
        encoded_frames = []
        scales = []

        # ``step`` is the overlap between consecutive chunks. The input must
        # end exactly on a chunk boundary, i.e. be k * stride + step long.
        step = chunk_length - stride
        if (input_length % stride) != step:
            raise ValueError(
                "The input length is not properly padded for batched chunked "
                "encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
            frame = input_values[:, offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = mx.stack(encoded_frames)

        return (encoded_frames, scales)

    @staticmethod
    def _linear_overlap_add(frames: List[mx.array], stride: int):
        """
        Blend overlapping frames into one signal with a weighted average.

        Each frame is placed ``stride`` samples after the previous one along
        axis 1 and weighted by a triangular window that peaks in the middle
        of the frame; the result is divided by the summed weights.

        Args:
            frames (List[mx.array]): Three-dimensional frames; axis 1 is the
              axis along which the frames overlap.
            stride (int): Offset between the starts of consecutive frames.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        dtype = frames[0].dtype
        N, frame_length, C = frames[0].shape
        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]

        # Dropping the two end points keeps every weight strictly positive,
        # so ``sum_weight`` is never zero where a frame contributes.
        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()

        weight = weight[:, None]
        sum_weight = mx.zeros((total_size, 1), dtype=dtype)
        out = mx.zeros((N, total_size, C), dtype=dtype)
        offset = 0

        for frame in frames:
            frame_length = frame.shape[1]
            out[:, offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        return out / sum_weight

    def _decode_frame(
        self, codes: mx.array, scale: Optional[mx.array] = None
    ) -> mx.array:
        """Decode the codes of one chunk, undoing the scaling if one is given."""
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale
        return outputs

    @property
    def channels(self):
        """The number of audio channels (``config.audio_channels``)."""
        return self.config.audio_channels

    @property
    def sampling_rate(self):
        """The audio sampling rate (``config.sampling_rate``)."""
        return self.config.sampling_rate

    @property
    def chunk_length(self):
        """The chunk length in samples, or ``None`` when chunking is disabled."""
        if self.config.chunk_length_s is None:
            return None
        else:
            return int(self.config.chunk_length_s * self.config.sampling_rate)

    @property
    def chunk_stride(self):
        """
        The offset in samples between consecutive chunks (at least 1), or
        ``None`` when either the chunk length or the overlap is not configured.
        """
        if self.config.chunk_length_s is None or self.config.overlap is None:
            return None
        else:
            return max(1, int((1.0 - self.config.overlap) * self.chunk_length))

    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: Union[mx.array, List[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end should be trimmed.

        Args:
            audio_codes (mx.array): Discrete codes as returned by ``encode``.
              The leading axis indexes chunks.
            audio_scales (mx.array or List[mx.array]): One scaling factor per
              chunk; an entry may be ``None`` to skip rescaling.
            padding_mask (mx.array, optional): Padding mask. If given and
              shorter than the decoded audio along axis 1, the output is
              truncated to its length.

        Returns:
            The decoded audio waveform.

        Raises:
            ValueError: If chunking is disabled and more than one chunk is
              given.
        """
        chunk_length = self.chunk_length
        if chunk_length is None:
            # ``encode`` stacks chunks along axis 0, so that is the axis that
            # must have exactly one entry when chunking is disabled.
            num_frames = audio_codes.shape[0]
            if num_frames != 1:
                raise ValueError(f"Expected one frame, got {num_frames}")
            audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
        else:
            decoded_frames = []

            for frame, scale in zip(audio_codes, audio_scales):
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)

            audio_values = self._linear_overlap_add(
                decoded_frames, self.chunk_stride or 1
            )

        # truncate based on padding mask
        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
            audio_values = audio_values[:, : padding_mask.shape[1]]
        return audio_values
|
37
encodec/example.py
Normal file
37
encodec/example.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from utils import load, load_audio, save_audio
|
||||||
|
|
||||||
|
# Fetch the 48 kHz EnCodec weights together with the matching preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Read the input file at the model's sampling rate and channel count.
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

# Turn the audio into model inputs. A list of arrays can be passed instead
# of a single array for batched processing.
feats, mask = processor(audio)


# Compress the audio into codes. Lowering the bandwidth compresses more but
# reduces the reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, padding_mask=mask, bandwidth=3)


# Reconstruct the waveform from the codes.
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, padding_mask=mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Keep the first batch entry and drop the padded samples.
reconstructed = reconstructed[0, : len(audio)]

# Write the result out as a wave file.
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
|
3
encodec/requirements.txt
Normal file
3
encodec/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
mlx>=0.18
|
||||||
|
numpy
|
||||||
|
huggingface_hub
|
66
encodec/test.py
Normal file
66
encodec/test.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import Audio, load_dataset
|
||||||
|
from transformers import AutoProcessor, EncodecModel
|
||||||
|
from utils import load, load_audio, preprocess_audio
|
||||||
|
|
||||||
|
|
||||||
|
def compare_processors():
    """Check that ``preprocess_audio`` matches the Transformers processor.

    Feeds the same random two-channel signal through both and asserts that
    the padded inputs and the padding masks are identical.
    """
    np.random.seed(0)
    num_samples = 95500
    signal = np.random.uniform(size=(2, num_samples)).astype(np.float32)

    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

    reference = processor(
        raw_audio=signal,
        sampling_rate=processor.sampling_rate,
        return_tensors="pt",
    )

    # The MLX preprocessor takes the time axis first, hence the transpose.
    mx_values, mx_mask = preprocess_audio(
        mx.array(signal).T,
        processor.sampling_rate,
        processor.chunk_length,
        processor.chunk_stride,
    )

    assert np.array_equal(reference["input_values"], mx_values.moveaxis(2, 1))
    assert np.array_equal(reference["padding_mask"], mx_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_models():
    """Check the MLX model against the Transformers model.

    Encodes and decodes the same random two-channel signal with both models
    and asserts that the codes match exactly and that the scales and the
    decoded audio match within tolerance.
    """
    pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
    mx_model, _ = load("mlx-community/encodec-48khz-float32")

    np.random.seed(0)
    num_samples = 190560
    signal = np.random.uniform(size=(1, num_samples, 2)).astype(np.float32)
    mask = np.ones((1, num_samples), dtype=np.int32)
    pt_mask = torch.tensor(mask)[None]

    # The Transformers model takes the channel axis before the time axis.
    pt_encoded = pt_model.encode(torch.tensor(signal).moveaxis(2, 1), pt_mask)
    mx_codes, mx_scales = mx_model.encode(mx.array(signal), mx.array(mask))

    assert np.array_equal(
        pt_encoded.audio_codes.numpy(), mx_codes
    ), "Encoding codes mismatch"

    for mx_scale, pt_scale in zip(mx_scales, pt_encoded.audio_scales):
        if mx_scale is None:
            continue
        assert np.allclose(pt_scale.numpy(), mx_scale, atol=1e-3, rtol=1e-4)

    pt_decoded = pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_mask
    )
    pt_audio = pt_decoded[0].squeeze().T.detach().numpy()

    mx_audio = mx_model.decode(mx_codes, mx_scales, mx.array(mask)).squeeze()

    assert np.allclose(
        pt_audio, mx_audio, atol=1e-4, rtol=1e-4
    ), "Decoding audio mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the preprocessing comparison first, then the model comparison.
    compare_processors()
    compare_models()
|
129
encodec/utils.py
Normal file
129
encodec/utils.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
import encodec
|
||||||
|
|
||||||
|
|
||||||
|
def save_audio(file: str, audio: mx.array, sampling_rate: int):
    """
    Save audio to a wave (.wav) file as 16-bit PCM.

    Args:
        file (str): The path of the file to write.
        audio (mx.array): The audio waveform, scaled by 32767 before the
          conversion to 16-bit integers.
        sampling_rate (int): The sample rate to store in the file.
    """
    from scipy.io.wavfile import write

    # Clamp to the int16 range before converting: a value outside of it has
    # no well-defined int16 representation and would corrupt the output.
    audio = mx.clip(audio * 32767, -32768, 32767).astype(mx.int16)
    write(file, sampling_rate, np.array(audio))
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(file: str, sampling_rate: int, channels: int):
    """
    Decode an audio file with ffmpeg and return it as an mx.array.

    ffmpeg converts the file to the requested sample rate and number of
    channels, so the ``ffmpeg`` executable must be available in PATH.

    Args:
        file (str): The audio file to open.
        sampling_rate (int): The sample rate of the returned audio.
        channels (int): The number of channels of the returned audio.

    Returns:
        A float32 mx.array with ``channels`` columns, obtained by dividing
        the signed 16-bit samples by 32767.

    Raises:
        RuntimeError: If ffmpeg exits with an error.
    """
    from subprocess import CalledProcessError, run

    # Ask ffmpeg for raw little-endian signed 16-bit PCM on stdout.
    ffmpeg_args = ["ffmpeg", "-nostdin"]
    ffmpeg_args += ["-threads", "0"]
    ffmpeg_args += ["-i", file]
    ffmpeg_args += ["-f", "s16le"]
    ffmpeg_args += ["-ac", str(channels)]
    ffmpeg_args += ["-acodec", "pcm_s16le"]
    ffmpeg_args += ["-ar", str(sampling_rate)]
    ffmpeg_args += ["-"]

    try:
        completed = run(ffmpeg_args, capture_output=True, check=True)
    except CalledProcessError as err:
        raise RuntimeError(f"Failed to load audio: {err.stderr.decode()}") from err

    samples = mx.array(np.frombuffer(completed.stdout, np.int16))
    # The samples are interleaved by channel; give each channel a column.
    samples = samples.reshape(-1, channels)
    return samples.astype(mx.float32) / 32767.0
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Pad audio to a common length and build the matching padding masks.

    Args:
        raw_audio (mx.array or List[mx.array]): A single sequence or a list
          of sequences. The first axis of each is the one that gets padded;
          one-dimensional sequences get a trailing axis appended.
        sampling_rate (int): The sampling rate of the audio. Not used by
          this function.
        chunk_length (int, optional): The model's chunk length. When given,
          the common length is extended to fit the chunking scheme.
        chunk_stride (int, optional): The model's chunk stride. Needed when
          ``chunk_length`` is given.

    Returns:
        A tuple ``(inputs, masks)`` of stacked arrays. The masks are boolean
        and ``False`` on the padded positions.
    """
    clips = raw_audio if isinstance(raw_audio, list) else [raw_audio]
    clips = [clip if clip.ndim != 1 else clip[..., None] for clip in clips]

    target_length = max(clip.shape[0] for clip in clips)
    if chunk_length is not None:
        remainder = target_length % chunk_stride
        target_length = target_length + chunk_length - remainder

    def _pad(clip):
        # Returns the zero-padded clip and its validity mask.
        valid = mx.ones((clip.shape[0],), dtype=mx.bool_)
        missing = target_length - clip.shape[0]
        if missing <= 0:
            return clip, valid
        return mx.pad(clip, ((0, missing), (0, 0))), mx.pad(valid, (0, missing))

    padded = [_pad(clip) for clip in clips]
    return (
        mx.stack([values for values, _ in padded]),
        mx.stack([valid for _, valid in padded]),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def load(path_or_repo):
    """
    Load the model and audio preprocessor.

    ``path_or_repo`` is used as a local directory when it exists; otherwise
    it is treated as a Hugging Face repo id and downloaded. The returned
    preprocessor is ``preprocess_audio`` with the sampling rate, chunk
    length and chunk stride of the loaded model already bound.
    """
    model_dir = Path(path_or_repo)
    if not model_dir.exists():
        downloaded = snapshot_download(
            repo_id=path_or_repo,
            allow_patterns=["*.json", "*.safetensors", "*.model"],
        )
        model_dir = Path(downloaded)

    with open(model_dir / "config.json", "r") as config_file:
        raw_config = json.load(config_file)
    config = SimpleNamespace(**raw_config)

    model = encodec.EncodecModel(config)
    model.load_weights(str(model_dir / "model.safetensors"))

    preprocessor_options = {
        "sampling_rate": config.sampling_rate,
        "chunk_length": model.chunk_length,
        "chunk_stride": model.chunk_stride,
    }
    processor = functools.partial(preprocess_audio, **preprocessor_options)

    # Evaluate the lazily loaded parameters before returning.
    mx.eval(model)
    return model, processor
|
@ -68,11 +68,10 @@ class LlavaModel(nn.Module):
|
|||||||
input_ids: Optional[mx.array] = None,
|
input_ids: Optional[mx.array] = None,
|
||||||
pixel_values: Optional[mx.array] = None,
|
pixel_values: Optional[mx.array] = None,
|
||||||
):
|
):
|
||||||
if pixel_values is None:
|
|
||||||
return self.language_model(input_ids)
|
|
||||||
|
|
||||||
# Get the input embeddings from the language model
|
# Get the input embeddings from the language model
|
||||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is None:
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
# Get the ouptut hidden states from the vision model
|
# Get the ouptut hidden states from the vision model
|
||||||
*_, hidden_states = self.vision_tower(
|
*_, hidden_states = self.vision_tower(
|
||||||
|
@ -16,7 +16,7 @@ conda install -c conda-forge mlx-lm
|
|||||||
|
|
||||||
The `mlx-lm` package also has:
|
The `mlx-lm` package also has:
|
||||||
|
|
||||||
- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
|
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
|
||||||
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
||||||
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
||||||
|
|
||||||
@ -29,7 +29,14 @@ from mlx_lm import load, generate
|
|||||||
|
|
||||||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||||
|
|
||||||
response = generate(model, tokenizer, prompt="hello", verbose=True)
|
prompt = "Write a story about Einstein"
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
To see a description of all the arguments you can do:
|
To see a description of all the arguments you can do:
|
||||||
@ -38,7 +45,9 @@ To see a description of all the arguments you can do:
|
|||||||
>>> help(generate)
|
>>> help(generate)
|
||||||
```
|
```
|
||||||
|
|
||||||
Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
|
Check out the [generation
|
||||||
|
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
|
||||||
|
to see how to use the API in more detail.
|
||||||
|
|
||||||
The `mlx-lm` package also comes with functionality to quantize and optionally
|
The `mlx-lm` package also comes with functionality to quantize and optionally
|
||||||
upload models to the Hugging Face Hub.
|
upload models to the Hugging Face Hub.
|
||||||
@ -77,6 +86,11 @@ model, tokenizer = load(repo)
|
|||||||
|
|
||||||
prompt = "Write a story about Einstein"
|
prompt = "Write a story about Einstein"
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||||
print(t, end="", flush=True)
|
print(t, end="", flush=True)
|
||||||
print()
|
print()
|
||||||
@ -122,10 +136,44 @@ mlx_lm.convert \
|
|||||||
--upload-repo mlx-community/my-4bit-mistral
|
--upload-repo mlx-community/my-4bit-mistral
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Long Prompts and Generations
|
||||||
|
|
||||||
|
MLX LM has some tools to scale efficiently to long prompts and generations:
|
||||||
|
|
||||||
|
- A rotating fixed-size key-value cache.
|
||||||
|
- Prompt caching
|
||||||
|
|
||||||
|
To use the rotating key-value cache pass the argument `--max-kv-size n` where
|
||||||
|
`n` can be any integer. Smaller values like `512` will use very little RAM but
|
||||||
|
result in worse quality. Larger values like `4096` or higher will use more RAM
|
||||||
|
but have better quality.
|
||||||
|
|
||||||
|
Caching prompts can substantially speed up reusing the same long context with
|
||||||
|
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cat prompt.txt | mlx_lm.cache_prompt \
|
||||||
|
--model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||||
|
--prompt - \
|
||||||
|
--kv-cache-file mistral_prompt.safetensors
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use the cached prompt with `mlx_lm.generate`:
|
||||||
|
|
||||||
|
```
|
||||||
|
mlx_lm.generate \
|
||||||
|
--kv-cache-file mistral_prompt.safetensors \
|
||||||
|
--prompt "\nSummarize the above text."
|
||||||
|
```
|
||||||
|
|
||||||
|
The cached prompt is treated as a prefix to the supplied prompt. Also notice
|
||||||
|
when using a cached prompt, the model to use is read from the cache and need
|
||||||
|
not be supplied explicitly.
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
|
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
|
||||||
models. If the model you want to run is not supported, file an
|
run is not supported, file an
|
||||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||||
submit a pull request.
|
submit a pull request.
|
||||||
|
|
||||||
|
26
llms/a.py
Normal file
26
llms/a.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import mlx_lm
|
||||||
|
|
||||||
|
# model, tokenizer = mlx_lm.load("mlx-community/SmolLM-1.7B-Instruct-fp16")
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")
draft_model, draft_tokenizer = mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")

# Question taken from Spec-Bench:
# https://github.com/hemingkx/Spec-Bench/blob/main/data/spec_bench/question.jsonl
prompt = "Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences."

# Wrap the question in the model's chat template.
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

generation_options = dict(
    verbose=True,
    max_tokens=500,
    temp=1.0,
    min_p=0.1,
    repetition_penalty=1.2,
    # draft_model=draft_model,
)
mlx_lm.generate(model, tokenizer, prompt=prompt, **generation_options)
|
41
llms/b.py
Normal file
41
llms/b.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import mlx_lm
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")

# Every unordered pair of distinct capital letters.
capital_letters = string.ascii_uppercase
distinct_pairs = []
for i, a in enumerate(capital_letters):
    for b in capital_letters[i + 1 :]:
        distinct_pairs.append((a, b))

num_prompts = 16
prompt_template = "Think of a real word containing both the letters {l1} and {l2}. Then, say 3 sentences which use the word."
prompts = [
    prompt_template.format(l1=first, l2=second)
    for first, second in random.sample(distinct_pairs, num_prompts)
]

# NOTE: the letter-pair prompts built above are replaced by this fixed list.
prompts = [
    "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
    "James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?",
    "Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?",
]


# Wrap each question in the model's chat template.
def _as_chat(question):
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )


prompts = [_as_chat(question) for question in prompts]

generation_options = dict(
    max_tokens=512,
    verbose=True,
    temp=1.0,
    min_p=0.1,
    repetition_penalty=1.2,
)
response = mlx_lm.batch_generate(
    model, tokenizer, prompts=prompts, **generation_options
)
|
11
llms/c.py
Normal file
11
llms/c.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import mlx_lm
|
||||||
|
|
||||||
|
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit")

# Print each piece of generated text as soon as it is produced.
stream = mlx_lm.stream_generate(
    model,
    tokenizer,
    prompt="Meta Llama 3.1 is a ",
    max_tokens=100,
)
for piece in stream:
    print(piece, end="", flush=True)
|
11
llms/d.py
Normal file
11
llms/d.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import mlx_lm
|
||||||
|
|
||||||
|
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit")

# Stream two prompts at once and print their outputs side by side, with the
# first one padded to a 30 character column.
stream = mlx_lm.stream_generate(
    model,
    tokenizer,
    prompt=["Meta Llama 3.1 is a ", "Google Gemma 2 is a "],
    max_tokens=20,
)
for pieces in stream:
    left, right = pieces[0], pieces[1]
    print(left.ljust(30) + right, flush=True)
|
21
llms/issue.txt
Normal file
21
llms/issue.txt
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
## Steps to reproduce
|
||||||
|
|
||||||
|
Run the following with and without `prefill_step_size=2` commented out:
|
||||||
|
|
||||||
|
```py
|
||||||
|
import mlx_lm
|
||||||
|
|
||||||
|
model, tokenizer = mlx_lm.load('/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit')
|
||||||
|
|
||||||
|
mlx_lm.generate(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
prompt="69 + 420= ",
|
||||||
|
verbose=True,
|
||||||
|
max_tokens=10,
|
||||||
|
max_kv_size=5,
|
||||||
|
prefill_step_size=2,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The output is different. I notice that the RotatingKVCache has length 5 with prefill and length 7 without.
|
@ -57,6 +57,9 @@ mlx_lm.lora \
|
|||||||
--iters 600
|
--iters 600
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
|
||||||
|
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
|
||||||
|
|
||||||
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
|
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
|
||||||
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
|
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
|
||||||
details on the data format see the section on [Data](#Data).
|
details on the data format see the section on [Data](#Data).
|
||||||
@ -67,8 +70,8 @@ mistralai/Mistral-7B-v0.1`.
|
|||||||
If `--model` points to a quantized model, then the training will use QLoRA,
|
If `--model` points to a quantized model, then the training will use QLoRA,
|
||||||
otherwise it will use regular LoRA.
|
otherwise it will use regular LoRA.
|
||||||
|
|
||||||
By default, the adapter config and weights are saved in `adapters/`. You can
|
By default, the adapter config and learned weights are saved in `adapters/`.
|
||||||
specify the output location with `--adapter-path`.
|
You can specify the output location with `--adapter-path`.
|
||||||
|
|
||||||
You can resume fine-tuning with an existing adapter with
|
You can resume fine-tuning with an existing adapter with
|
||||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||||
@ -118,7 +121,7 @@ mlx_lm.fuse --model <path_to_model>
|
|||||||
```
|
```
|
||||||
|
|
||||||
This will by default load the adapters from `adapters/`, and save the fused
|
This will by default load the adapters from `adapters/`, and save the fused
|
||||||
model in the path `lora_fused_model/`. All of these are configurable.
|
model in the path `fused_model/`. All of these are configurable.
|
||||||
|
|
||||||
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
|
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
|
||||||
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
|
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
|
||||||
@ -141,7 +144,7 @@ mlx_lm.fuse \
|
|||||||
--export-gguf
|
--export-gguf
|
||||||
```
|
```
|
||||||
|
|
||||||
This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
|
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
|
||||||
can specify the file name with `--gguf-path`.
|
can specify the file name with `--gguf-path`.
|
||||||
|
|
||||||
## Data
|
## Data
|
||||||
@ -160,50 +163,86 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
|||||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||||
loader expects a `test.jsonl` in the data directory.
|
loader expects a `test.jsonl` in the data directory.
|
||||||
|
|
||||||
Currently, `*.jsonl` files support three data formats: `chat`,
|
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
|
||||||
`completions`, and `text`. Here are three examples of these formats:
|
data formats. Here are examples of these formats:
|
||||||
|
|
||||||
`chat`:
|
`chat`:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assist you today."}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
`tools`:
|
||||||
|
|
||||||
|
```jsonl
|
||||||
|
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>View the expanded single data tool format</summary>
|
||||||
|
|
||||||
```jsonl
|
```jsonl
|
||||||
{
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{ "role": "user", "content": "What is the weather in San Francisco?" },
|
||||||
"role": "system",
|
{
|
||||||
"content": "You are a helpful assistant."
|
"role": "assistant",
|
||||||
},
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"id": "call_id",
|
||||||
"content": "Hello."
|
"type": "function",
|
||||||
},
|
"function": {
|
||||||
{
|
"name": "get_current_weather",
|
||||||
"role": "assistant",
|
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
|
||||||
"content": "How can I assistant you today."
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and country, eg. San Francisco, USA"
|
||||||
|
},
|
||||||
|
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
|
||||||
|
},
|
||||||
|
"required": ["location", "format"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
`completions`:
|
`completions`:
|
||||||
|
|
||||||
```jsonl
|
```jsonl
|
||||||
{
|
{"prompt": "What is the capital of France?", "completion": "Paris."}
|
||||||
"prompt": "What is the capital of France?",
|
|
||||||
"completion": "Paris."
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`text`:
|
`text`:
|
||||||
|
|
||||||
```jsonl
|
```jsonl
|
||||||
{
|
{"text": "This is an example for the model."}
|
||||||
"text": "This is an example for the model."
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Note, the format is automatically determined by the dataset. Note also, keys in
|
Note, the format is automatically determined by the dataset. Note also, keys in
|
||||||
each line not expected by the loader will be ignored.
|
each line not expected by the loader will be ignored.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Each example in the datasets must be on a single line. Do not put more than
|
||||||
|
> one example per line and do not split an example across multiple lines.
|
||||||
|
|
||||||
### Hugging Face Datasets
|
### Hugging Face Datasets
|
||||||
|
|
||||||
To use Hugging Face datasets, first install the `datasets` package:
|
To use Hugging Face datasets, first install the `datasets` package:
|
||||||
@ -212,7 +251,13 @@ To use Hugging Face datasets, first install the `datasets` package:
|
|||||||
pip install datasets
|
pip install datasets
|
||||||
```
|
```
|
||||||
|
|
||||||
Specify the Hugging Face dataset arguments in a YAML config. For example:
|
If the Hugging Face dataset is already in a supported format, you can specify
|
||||||
|
it on the command line. For example, pass `--data mlx-community/wikisql` to
|
||||||
|
train on the pre-formatted WikiSQL data.
|
||||||
|
|
||||||
|
Otherwise, provide a mapping of keys in the dataset to the features MLX LM
|
||||||
|
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
|
||||||
|
example:
|
||||||
|
|
||||||
```
|
```
|
||||||
hf_dataset:
|
hf_dataset:
|
||||||
@ -231,11 +276,13 @@ hf_dataset:
|
|||||||
- Arguments specified in `config` will be passed as keyword arguments to
|
- Arguments specified in `config` will be passed as keyword arguments to
|
||||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||||
|
|
||||||
In general, for the `chat` and `completions` formats, Hugging Face [chat
|
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
|
||||||
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
[chat
|
||||||
the model's chat template by default. If the model does not have a chat
|
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||||
template, then Hugging Face will use a default. For example, the final text in
|
are used. This applies the model's chat template by default. If the model does
|
||||||
the `chat` example above with Hugging Face's default template becomes:
|
not have a chat template, then Hugging Face will use a default. For example,
|
||||||
|
the final text in the `chat` example above with Hugging Face's default template
|
||||||
|
becomes:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
<|im_start|>system
|
<|im_start|>system
|
||||||
@ -263,7 +310,7 @@ of memory. Here are some tips to reduce memory use should you need to do so:
|
|||||||
setting this to `2` or `1` will reduce memory consumption. This may slow
|
setting this to `2` or `1` will reduce memory consumption. This may slow
|
||||||
things down a little, but will also reduce the memory use.
|
things down a little, but will also reduce the memory use.
|
||||||
|
|
||||||
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
|
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
|
||||||
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
||||||
needed for back propagation. It may also reduce the quality of the
|
needed for back propagation. It may also reduce the quality of the
|
||||||
fine-tuned model if you are fine-tuning with a lot of data.
|
fine-tuned model if you are fine-tuning with a lot of data.
|
||||||
@ -285,7 +332,7 @@ mlx_lm.lora \
|
|||||||
--model mistralai/Mistral-7B-v0.1 \
|
--model mistralai/Mistral-7B-v0.1 \
|
||||||
--train \
|
--train \
|
||||||
--batch-size 1 \
|
--batch-size 1 \
|
||||||
--lora-layers 4 \
|
--num-layers 4 \
|
||||||
--data wikisql
|
--data wikisql
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -295,4 +342,5 @@ tokens-per-second, using the MLX Example
|
|||||||
data set.
|
data set.
|
||||||
|
|
||||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||||
|
|
||||||
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
||||||
|
@ -85,3 +85,17 @@ curl localhost:8080/v1/chat/completions \
|
|||||||
|
|
||||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||||
relative to the directory the server was started in.
|
relative to the directory the server was started in.
|
||||||
|
|
||||||
|
### List Models
|
||||||
|
|
||||||
|
Use the `v1/models` endpoint to list available models:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl localhost:8080/v1/models -H "Content-Type: application/json"
|
||||||
|
```
|
||||||
|
|
||||||
|
This will return a list of locally available models where each model in the
|
||||||
|
list contains the following fields:
|
||||||
|
|
||||||
|
- `"id"`: The Hugging Face repo id.
|
||||||
|
- `"created"`: A timestamp representing the model creation time.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from .utils import convert, generate, load, stream_generate
|
from ._version import __version__
|
||||||
from .version import __version__
|
from .utils import convert, generate, load, stream_generate, batch_generate
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.17.1"
|
__version__ = "0.18.2"
|
@ -56,7 +56,7 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=None,
|
||||||
help="Set the maximum key-value cache size",
|
help="Set the maximum key-value cache size",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -139,11 +139,15 @@ def main():
|
|||||||
print("Saving...")
|
print("Saving...")
|
||||||
cache_dict = {}
|
cache_dict = {}
|
||||||
for i, c in enumerate(cache):
|
for i, c in enumerate(cache):
|
||||||
cache_dict[f"{i}_keys"] = c.state[0]
|
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
|
||||||
cache_dict[f"{i}_values"] = c.state[1]
|
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
|
||||||
metadata = {}
|
metadata = {}
|
||||||
metadata["model"] = args.model
|
metadata["model"] = args.model
|
||||||
metadata["chat_template"] = tokenizer.chat_template
|
metadata["chat_template"] = tokenizer.chat_template
|
||||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||||
metadata["max_kv_size"] = str(args.max_kv_size)
|
metadata["max_kv_size"] = str(args.max_kv_size)
|
||||||
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
# The path to the local model directory or Hugging Face repo.
|
# The path to the local model directory or Hugging Face repo.
|
||||||
model: "mlx_model"
|
model: "mlx_model"
|
||||||
|
|
||||||
# Whether or not to train (boolean)
|
# Whether or not to train (boolean)
|
||||||
train: true
|
train: true
|
||||||
|
|
||||||
|
# The fine-tuning method: "lora", "dora", or "full".
|
||||||
|
fine_tune_type: lora
|
||||||
|
|
||||||
# Directory with {train, valid, test}.jsonl files
|
# Directory with {train, valid, test}.jsonl files
|
||||||
data: "/path/to/training/data"
|
data: "/path/to/training/data"
|
||||||
|
|
||||||
@ -51,9 +55,6 @@ max_seq_length: 2048
|
|||||||
# Use gradient checkpointing to reduce memory use.
|
# Use gradient checkpointing to reduce memory use.
|
||||||
grad_checkpoint: false
|
grad_checkpoint: false
|
||||||
|
|
||||||
# Use DoRA instead of LoRA.
|
|
||||||
use_dora: false
|
|
||||||
|
|
||||||
# LoRA parameters can only be specified in a config file
|
# LoRA parameters can only be specified in a config file
|
||||||
lora_parameters:
|
lora_parameters:
|
||||||
# The layer keys to apply LoRA to.
|
# The layer keys to apply LoRA to.
|
||||||
|
@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
|
|||||||
from .gguf import convert_to_gguf
|
from .gguf import convert_to_gguf
|
||||||
from .tuner.dora import DoRAEmbedding, DoRALinear
|
from .tuner.dora import DoRAEmbedding, DoRALinear
|
||||||
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
|
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
|
||||||
from .tuner.utils import apply_lora_layers, dequantize
|
from .tuner.utils import dequantize, load_adapters
|
||||||
from .utils import (
|
from .utils import (
|
||||||
fetch_from_hub,
|
fetch_from_hub,
|
||||||
get_model_path,
|
get_model_path,
|
||||||
@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-path",
|
"--save-path",
|
||||||
default="lora_fused_model",
|
default="fused_model",
|
||||||
help="The path to save the fused model.",
|
help="The path to save the fused model.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -77,17 +77,14 @@ def main() -> None:
|
|||||||
model, config, tokenizer = fetch_from_hub(model_path)
|
model, config, tokenizer = fetch_from_hub(model_path)
|
||||||
|
|
||||||
model.freeze()
|
model.freeze()
|
||||||
model = apply_lora_layers(model, args.adapter_path)
|
model = load_adapters(model, args.adapter_path)
|
||||||
|
|
||||||
fused_linears = [
|
fused_linears = [
|
||||||
(n, m.fuse())
|
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
|
||||||
for n, m in model.named_modules()
|
|
||||||
if isinstance(
|
|
||||||
m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
model.update_modules(tree_unflatten(fused_linears))
|
if fused_linears:
|
||||||
|
model.update_modules(tree_unflatten(fused_linears))
|
||||||
|
|
||||||
if args.de_quantize:
|
if args.de_quantize:
|
||||||
print("De-quantizing model")
|
print("De-quantizing model")
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@ -12,7 +13,10 @@ DEFAULT_MAX_TOKENS = 100
|
|||||||
DEFAULT_TEMP = 0.6
|
DEFAULT_TEMP = 0.6
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
DEFAULT_MAX_KV_SIZE = 1024
|
|
||||||
|
|
||||||
|
def str2bool(string):
|
||||||
|
return string.lower() not in ["false", "f"]
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
@ -40,7 +44,9 @@ def setup_arg_parser():
|
|||||||
help="End of sequence token for tokenizer",
|
help="End of sequence token for tokenizer",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
|
"--prompt",
|
||||||
|
default=DEFAULT_PROMPT,
|
||||||
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens",
|
"--max-tokens",
|
||||||
@ -66,6 +72,12 @@ def setup_arg_parser():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Use the default chat template",
|
help="Use the default chat template",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--colorize",
|
"--colorize",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -81,6 +93,7 @@ def setup_arg_parser():
|
|||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Set the maximum key-value cache size",
|
help="Set the maximum key-value cache size",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-file",
|
"--kv-cache-file",
|
||||||
@ -171,14 +184,19 @@ def main():
|
|||||||
if args.use_default_chat_template:
|
if args.use_default_chat_template:
|
||||||
if tokenizer.chat_template is None:
|
if tokenizer.chat_template is None:
|
||||||
tokenizer.chat_template = tokenizer.default_chat_template
|
tokenizer.chat_template = tokenizer.default_chat_template
|
||||||
elif tokenizer.chat_template is None:
|
elif cache_history is not None:
|
||||||
tokenizer.chat_template = metadata["chat_template"]
|
tokenizer.chat_template = metadata["chat_template"]
|
||||||
|
|
||||||
if not args.ignore_chat_template and (
|
if not args.ignore_chat_template and (
|
||||||
hasattr(tokenizer, "apply_chat_template")
|
hasattr(tokenizer, "apply_chat_template")
|
||||||
and tokenizer.chat_template is not None
|
and tokenizer.chat_template is not None
|
||||||
):
|
):
|
||||||
messages = [{"role": "user", "content": args.prompt}]
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
|
||||||
|
}
|
||||||
|
]
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
@ -195,29 +213,30 @@ def main():
|
|||||||
else:
|
else:
|
||||||
prompt = args.prompt
|
prompt = args.prompt
|
||||||
|
|
||||||
|
if args.colorize and not args.verbose:
|
||||||
|
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||||
formatter = colorprint_by_t0 if args.colorize else None
|
formatter = colorprint_by_t0 if args.colorize else None
|
||||||
|
|
||||||
# Determine the max kv size from the kv cache or passed arguments
|
# Determine the max kv size from the kv cache or passed arguments
|
||||||
max_kv_size = args.max_kv_size
|
max_kv_size = args.max_kv_size
|
||||||
if max_kv_size is None:
|
if cache_history is not None:
|
||||||
max_kv_size = (
|
max_kv_size = metadata["max_kv_size"]
|
||||||
int(metadata["max_kv_size"])
|
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
|
||||||
if cache_history is not None
|
|
||||||
else DEFAULT_MAX_KV_SIZE
|
|
||||||
)
|
|
||||||
|
|
||||||
generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
args.max_tokens,
|
args.max_tokens,
|
||||||
verbose=True,
|
verbose=args.verbose,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
max_kv_size=max_kv_size,
|
max_kv_size=max_kv_size,
|
||||||
cache_history=cache_history,
|
cache_history=cache_history,
|
||||||
)
|
)
|
||||||
|
if not args.verbose:
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -67,7 +67,7 @@ class HfVocab:
|
|||||||
def get_token_type(
|
def get_token_type(
|
||||||
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||||
) -> TokenType:
|
) -> TokenType:
|
||||||
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
|
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||||
return TokenType.BYTE
|
return TokenType.BYTE
|
||||||
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
|
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
|
||||||
|
|
||||||
@ -77,9 +77,7 @@ class HfVocab:
|
|||||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||||
for text in self.added_tokens_list:
|
for text in self.added_tokens_list:
|
||||||
if text in self.specials:
|
if text in self.specials:
|
||||||
toktype = self.get_token_type(
|
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
|
||||||
self.specials[text], b"", self.special_ids
|
|
||||||
)
|
|
||||||
score = self.get_token_score(self.specials[text])
|
score = self.get_token_score(self.specials[text])
|
||||||
else:
|
else:
|
||||||
toktype = TokenType.USER_DEFINED
|
toktype = TokenType.USER_DEFINED
|
||||||
@ -243,15 +241,18 @@ def prepare_metadata(config, vocab):
|
|||||||
metadata["tokenizer.ggml.tokens"] = tokens
|
metadata["tokenizer.ggml.tokens"] = tokens
|
||||||
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
|
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
|
||||||
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
|
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
|
||||||
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
|
if vocab.tokenizer.bos_token_id is not None:
|
||||||
vocab.tokenizer.bos_token_id, dtype=mx.uint32
|
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
|
||||||
)
|
vocab.tokenizer.bos_token_id, dtype=mx.uint32
|
||||||
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
|
)
|
||||||
vocab.tokenizer.eos_token_id, dtype=mx.uint32
|
if vocab.tokenizer.eos_token_id is not None:
|
||||||
)
|
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
|
||||||
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
|
vocab.tokenizer.eos_token_id, dtype=mx.uint32
|
||||||
vocab.tokenizer.unk_token_id, dtype=mx.uint32
|
)
|
||||||
)
|
if vocab.tokenizer.unk_token_id is not None:
|
||||||
|
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
|
||||||
|
vocab.tokenizer.unk_token_id, dtype=mx.uint32
|
||||||
|
)
|
||||||
|
|
||||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||||
return metadata
|
return metadata
|
||||||
|
@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
|
|||||||
from .tuner.datasets import load_dataset
|
from .tuner.datasets import load_dataset
|
||||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||||
from .tuner.utils import (
|
from .tuner.utils import (
|
||||||
apply_lora_layers,
|
|
||||||
build_schedule,
|
build_schedule,
|
||||||
linear_to_lora_layers,
|
linear_to_lora_layers,
|
||||||
|
load_adapters,
|
||||||
print_trainable_parameters,
|
print_trainable_parameters,
|
||||||
)
|
)
|
||||||
from .utils import load, save_config
|
from .utils import load, save_config
|
||||||
@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
|
|||||||
CONFIG_DEFAULTS = {
|
CONFIG_DEFAULTS = {
|
||||||
"model": "mlx_model",
|
"model": "mlx_model",
|
||||||
"train": False,
|
"train": False,
|
||||||
|
"fine_tune_type": "lora",
|
||||||
"data": "data/",
|
"data": "data/",
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
"lora_layers": 16,
|
"num_layers": 16,
|
||||||
"batch_size": 4,
|
"batch_size": 4,
|
||||||
"iters": 1000,
|
"iters": 1000,
|
||||||
"val_batches": 25,
|
"val_batches": 25,
|
||||||
@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
|
|||||||
"max_seq_length": 2048,
|
"max_seq_length": 2048,
|
||||||
"lr_schedule": None,
|
"lr_schedule": None,
|
||||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||||
"use_dora": False,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -79,10 +79,20 @@ def build_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data",
|
"--data",
|
||||||
type=str,
|
type=str,
|
||||||
help="Directory with {train, valid, test}.jsonl files",
|
help=(
|
||||||
|
"Directory with {train, valid, test}.jsonl files or the name "
|
||||||
|
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-layers",
|
"--fine-tune-type",
|
||||||
|
type=str,
|
||||||
|
choices=["lora", "dora", "full"],
|
||||||
|
default="lora",
|
||||||
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-layers",
|
||||||
type=int,
|
type=int,
|
||||||
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
|
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
|
||||||
)
|
)
|
||||||
@ -107,12 +117,12 @@ def build_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume-adapter-file",
|
"--resume-adapter-file",
|
||||||
type=str,
|
type=str,
|
||||||
help="Load path to resume training with the given adapters.",
|
help="Load path to resume training from the given fine-tuned weights.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adapter-path",
|
"--adapter-path",
|
||||||
type=str,
|
type=str,
|
||||||
help="Save/load path for the adapters.",
|
help="Save/load path for the fine-tuned weights.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-every",
|
"--save-every",
|
||||||
@ -148,9 +158,6 @@ def build_parser():
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
||||||
parser.add_argument(
|
|
||||||
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
|
|
||||||
)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -162,21 +169,31 @@ def train_model(
|
|||||||
valid_set,
|
valid_set,
|
||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
# Freeze all layers
|
|
||||||
model.freeze()
|
model.freeze()
|
||||||
|
if args.fine_tune_type == "full":
|
||||||
|
for l in model.layers[-max(args.num_layers, 0) :]:
|
||||||
|
l.unfreeze()
|
||||||
|
elif args.fine_tune_type in ["lora", "dora"]:
|
||||||
|
# Convert linear layers to lora/dora layers and unfreeze in the process
|
||||||
|
linear_to_lora_layers(
|
||||||
|
model,
|
||||||
|
args.num_layers,
|
||||||
|
args.lora_parameters,
|
||||||
|
use_dora=(args.fine_tune_type == "dora"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
|
||||||
|
|
||||||
# Convert linear layers to lora layers and unfreeze in the process
|
# Resume from weights if provided
|
||||||
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
|
|
||||||
|
|
||||||
# Resume training the given adapters.
|
|
||||||
if args.resume_adapter_file is not None:
|
if args.resume_adapter_file is not None:
|
||||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
|
||||||
model.load_weights(args.resume_adapter_file, strict=False)
|
model.load_weights(args.resume_adapter_file, strict=False)
|
||||||
|
|
||||||
print_trainable_parameters(model)
|
print_trainable_parameters(model)
|
||||||
|
|
||||||
adapter_path = Path(args.adapter_path)
|
adapter_path = Path(args.adapter_path)
|
||||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
adapter_file = adapter_path / "adapters.safetensors"
|
adapter_file = adapter_path / "adapters.safetensors"
|
||||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||||
|
|
||||||
@ -240,7 +257,7 @@ def run(args, training_callback: TrainingCallback = None):
|
|||||||
if args.test and not args.train:
|
if args.test and not args.train:
|
||||||
# Allow testing without LoRA layers by providing empty path
|
# Allow testing without LoRA layers by providing empty path
|
||||||
if args.adapter_path != "":
|
if args.adapter_path != "":
|
||||||
apply_lora_layers(model, args.adapter_path)
|
load_adapters(model, args.adapter_path)
|
||||||
|
|
||||||
elif args.train:
|
elif args.train:
|
||||||
print("Training")
|
print("Training")
|
||||||
|
231
llms/mlx_lm/models/mamba.py
Normal file
231
llms/mlx_lm/models/mamba.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs(BaseModelArgs):
|
||||||
|
model_type: str
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
state_size: int
|
||||||
|
num_hidden_layers: int
|
||||||
|
conv_kernel: int
|
||||||
|
use_bias: bool
|
||||||
|
use_conv_bias: bool
|
||||||
|
time_step_rank: int
|
||||||
|
tie_word_embeddings: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
||||||
|
self.hidden_size = self.d_model
|
||||||
|
if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
|
||||||
|
self.intermediate_size = self.d_inner
|
||||||
|
if not hasattr(self, "state_size") and hasattr(self, "d_state"):
|
||||||
|
self.state_size = self.d_state
|
||||||
|
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
|
||||||
|
self.num_hidden_layers = self.n_layer
|
||||||
|
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
|
||||||
|
self.num_hidden_layers = self.n_layers
|
||||||
|
if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
|
||||||
|
self.conv_kernel = self.d_conv
|
||||||
|
if not hasattr(self, "use_bias") and hasattr(self, "bias"):
|
||||||
|
self.use_bias = self.bias
|
||||||
|
if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
|
||||||
|
self.use_conv_bias = self.conv_bias
|
||||||
|
|
||||||
|
if self.time_step_rank == "auto":
|
||||||
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
|
||||||
|
|
||||||
|
class MambaCache:
|
||||||
|
def __init__(self):
|
||||||
|
self.cache = [None, None]
|
||||||
|
|
||||||
|
def __setitem__(self, idx, value):
|
||||||
|
self.cache[idx] = value
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.cache[idx]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self.cache
|
||||||
|
|
||||||
|
|
||||||
|
class DepthWiseConv1d(nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.padding = padding
|
||||||
|
self.weight = mx.random.normal((self.channels, kernel_size, 1))
|
||||||
|
self.bias = mx.zeros((channels,)) if bias else None
|
||||||
|
|
||||||
|
def __call__(self, x, cache=None):
|
||||||
|
B, L, C = x.shape
|
||||||
|
groups, K, _ = self.weight.shape
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
x = mx.concatenate([cache, x], axis=1)
|
||||||
|
else:
|
||||||
|
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||||
|
|
||||||
|
y = mx.conv_general(x, self.weight, groups=groups)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
y = y + self.bias
|
||||||
|
|
||||||
|
return y, x[:, -K + 1 :, :]
|
||||||
|
|
||||||
|
|
||||||
|
class MambaBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
self.hidden_size = args.hidden_size
|
||||||
|
self.ssm_state_size = args.state_size
|
||||||
|
self.conv_kernel_size = args.conv_kernel
|
||||||
|
self.intermediate_size = args.intermediate_size
|
||||||
|
self.time_step_rank = int(args.time_step_rank)
|
||||||
|
self.use_conv_bias = args.use_conv_bias
|
||||||
|
|
||||||
|
self.in_proj = nn.Linear(
|
||||||
|
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv1d = DepthWiseConv1d(
|
||||||
|
channels=self.intermediate_size,
|
||||||
|
kernel_size=self.conv_kernel_size,
|
||||||
|
bias=self.use_conv_bias,
|
||||||
|
padding=self.conv_kernel_size - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.x_proj = nn.Linear(
|
||||||
|
self.intermediate_size,
|
||||||
|
self.time_step_rank + 2 * self.ssm_state_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
|
||||||
|
|
||||||
|
A = mx.repeat(
|
||||||
|
mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
|
||||||
|
repeats=self.intermediate_size,
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
self.A_log = mx.log(A)
|
||||||
|
self.D = mx.ones([self.intermediate_size])
|
||||||
|
|
||||||
|
self.out_proj = nn.Linear(
|
||||||
|
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def ssm_step(self, x, state=None):
|
||||||
|
A = -mx.exp(self.A_log)
|
||||||
|
D = self.D
|
||||||
|
deltaBC = self.x_proj(x)
|
||||||
|
delta, B, C = mx.split(
|
||||||
|
deltaBC,
|
||||||
|
indices_or_sections=[
|
||||||
|
self.time_step_rank,
|
||||||
|
self.time_step_rank + self.ssm_state_size,
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
|
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||||
|
if state is not None:
|
||||||
|
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
||||||
|
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
|
||||||
|
y = y + D * x
|
||||||
|
return y, new_state
|
||||||
|
|
||||||
|
def __call__(self, x, cache):
|
||||||
|
B, T, D = x.shape
|
||||||
|
if cache is None:
|
||||||
|
cache = [None, None]
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for t in range(T):
|
||||||
|
xt = x[:, t, :]
|
||||||
|
xz = self.in_proj(xt)
|
||||||
|
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
|
||||||
|
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
||||||
|
x_t = conv_out.squeeze(1)
|
||||||
|
x_t = nn.silu(x_t)
|
||||||
|
y_t, cache[1] = self.ssm_step(x_t, cache[1])
|
||||||
|
z_t = nn.silu(z_t)
|
||||||
|
output_t = y_t * z_t
|
||||||
|
output_t = self.out_proj(output_t)
|
||||||
|
outputs.append(output_t)
|
||||||
|
output = mx.stack(outputs, axis=1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.mixer = MambaBlock(args)
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, cache):
|
||||||
|
return self.mixer(self.norm(x), cache) + x
|
||||||
|
|
||||||
|
|
||||||
|
class Mamba(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
|
||||||
|
self.norm_f = nn.RMSNorm(args.hidden_size)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, cache):
|
||||||
|
x = self.embeddings(x)
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
x = layer(x, c)
|
||||||
|
return self.norm_f(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
    """Mamba language model: backbone plus an (optionally tied) output projection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba(args)
        # A separate output head exists only when embeddings are not tied.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        """Return logits for ``inputs``; ``cache`` is passed through to the backbone."""
        # Unpacking keeps the original requirement that ``inputs`` is 2-D.
        _, _ = inputs.shape

        hidden = self.backbone(inputs, cache)

        if self.args.tie_word_embeddings:
            return self.backbone.embeddings.as_linear(hidden)
        return self.lm_head(hidden)

    def sanitize(self, weights):
        """Swap axes 1 and 2 of every 3-D ``conv1d.weight`` entry, in place.

        Returns the same ``weights`` mapping that was passed in.
        """
        for name, tensor in weights.items():
            is_conv_weight = "conv1d.weight" in name
            if is_conv_weight and tensor.ndim == 3:
                weights[name] = tensor.moveaxis(2, 1)
        return weights

    def make_cache(self, batch_size: int = 1):
        """Create one fresh ``MambaCache`` per layer (``batch_size`` is not used)."""
        return [MambaCache() for _ in self.layers]

    @property
    def layers(self):
        """The backbone's list of layers."""
        return self.backbone.layers
|
227
llms/mlx_lm/models/nemotron.py
Normal file
227
llms/mlx_lm/models/nemotron.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration fields for the Nemotron model, validated in ``__post_init__``."""

    model_type: str
    hidden_size: int
    hidden_act: str
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    partial_rotary_factor: float = 0.5
    rope_theta: float = 10000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False

    def __post_init__(self):
        """Validate ``rope_scaling`` when it is set.

        Raises:
            ValueError: if ``factor`` is missing, if neither ``type`` nor
                ``rope_type`` is given, or if the type is not ``linear``.
        """
        scaling = self.rope_scaling
        if not scaling:
            return

        if "factor" not in scaling:
            raise ValueError("rope_scaling must contain 'factor'")

        # Either spelling of the key is accepted; "type" wins when both are set.
        rope_type = scaling.get("type") or scaling.get("rope_type")
        if rope_type is None:
            raise ValueError(
                "rope_scaling must contain either 'type' or 'rope_type'"
            )
        if rope_type != "linear":
            raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
def relu_squared(x):
    """Apply ReLU to ``x`` and square the result."""
    rectified = nn.relu(x)
    return rectified.square()
|
||||||
|
|
||||||
|
|
||||||
|
class NemotronLayerNorm1P(nn.LayerNorm):
    """LayerNorm variant that adds one to the stored weight before applying it."""

    def __call__(self, x):
        """Normalize ``x`` with ``weight + 1`` and ``bias`` (each optional)."""
        scale = None
        if "weight" in self:
            scale = self.weight + 1

        shift = None
        if "bias" in self:
            shift = self.bias

        return mx.fast.layer_norm(x, scale, shift, self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Self-attention with separate query and key/value head counts and
    rotary embeddings applied to a fraction of each head's dimensions."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        # An explicit head_dim wins; otherwise split hidden_size across heads.
        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
        self.partial_rotary_factor = args.partial_rotary_factor

        self.scale = head_dim**-0.5
        attention_bias = getattr(args, "attention_bias", False)

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        rope_scale = 1.0
        if args.rope_scaling:
            # ModelArgs.__post_init__ accepts either key, so look the type up
            # the same way here instead of indexing "type" directly (which
            # raised KeyError for configs that only provide "rope_type").
            rope_type = args.rope_scaling.get("type") or args.rope_scaling.get(
                "rope_type"
            )
            if rope_type == "linear":
                factor = args.rope_scaling["factor"]
                # Integer factors (e.g. 2) are valid; reject non-numbers
                # explicitly rather than with an assert that -O strips.
                if isinstance(factor, bool) or not isinstance(factor, (int, float)):
                    raise ValueError(
                        "rope_scaling 'factor' must be a number, "
                        f"got {type(factor).__name__}"
                    )
                rope_scale = 1 / factor

        # Only the first partial_rotary_factor * head_dim dims are rotated.
        self.rope = nn.RoPE(
            int(self.partial_rotary_factor * self.head_dim),
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        """Attend over ``x`` and return the output projection.

        Args:
            x: input whose shape unpacks as ``(B, L, _)``.
            mask: optional mask forwarded to scaled dot-product attention.
            cache: optional cache; when given, rotary offsets come from
                ``cache.offset`` and keys/values are replaced by the result
                of ``cache.update_and_fetch``.
        """
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """Two-layer feed-forward block with a squared-ReLU activation."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        use_bias = args.mlp_bias

        self.down_proj = nn.Linear(inner_dim, model_dim, bias=use_bias)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=use_bias)

    def __call__(self, x) -> mx.array:
        """Project up, apply ``relu_squared``, and project back down."""
        expanded = self.up_proj(x)
        activated = relu_squared(expanded)
        return self.down_proj(activated)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
    """Pre-norm block: attention and MLP sub-layers, each with a residual add."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
        self.post_attention_layernorm = NemotronLayerNorm1P(
            args.hidden_size, eps=args.norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
    ) -> mx.array:
        """Run attention then the MLP, adding each result to its own input."""
        attn_out = self.self_attn(self.input_layernorm(x), mask, cache)
        after_attn = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
|
||||||
|
|
||||||
|
|
||||||
|
class NemotronModel(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, and a final norm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return the normalized hidden states for ``inputs``.

        The mask is built from the embedded input and the caller's ``cache``;
        when ``cache`` is ``None`` every layer is called with ``cache=None``.
        """
        hidden = self.embed_tokens(inputs)

        # The mask must be built from the caller's cache before the default
        # per-layer list is substituted.
        mask = create_attention_mask(hidden, cache)

        layer_caches = [None] * len(self.layers) if cache is None else cache

        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, mask, cache=block_cache)

        return self.norm(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
    """Nemotron language model: ``NemotronModel`` plus an (optionally tied) head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = NemotronModel(args)
        # A separate output head exists only when embeddings are not tied.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        """Return the output projection of the model's hidden states."""
        hidden = self.model(inputs, cache)
        if self.args.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(hidden)
        return self.lm_head(hidden)

    @property
    def layers(self):
        """The inner model's list of transformer blocks."""
        return self.model.layers

    @property
    def head_dim(self):
        """Configured ``head_dim``, else ``hidden_size // num_attention_heads``."""
        configured = self.args.head_dim
        if configured:
            return configured
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        """Number of key/value heads from the configuration."""
        return self.args.num_key_value_heads
|
@ -11,6 +11,7 @@ from pathlib import Path
|
|||||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from huggingface_hub import scan_cache_dir
|
||||||
|
|
||||||
from .utils import generate_step, load
|
from .utils import generate_step, load
|
||||||
|
|
||||||
@ -596,6 +597,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
):
|
):
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
prompt = self.tokenizer.apply_chat_template(
|
||||||
body["messages"],
|
body["messages"],
|
||||||
|
body.get("tools", None),
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
@ -621,6 +623,46 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
prompt = self.tokenizer.encode(prompt_text)
|
prompt = self.tokenizer.encode(prompt_text)
|
||||||
return mx.array(prompt)
|
return mx.array(prompt)
|
||||||
|
|
||||||
|
def do_GET(self):
    """
    Answer a client's GET request.

    ``/v1/models`` is delegated to ``handle_models_request``; any other path
    gets a 404 status with the body ``Not Found``.
    """
    if self.path != "/v1/models":
        self._set_completion_headers(404)
        self.end_headers()
        self.wfile.write(b"Not Found")
        return

    self.handle_models_request()
|
||||||
|
|
||||||
|
def handle_models_request(self):
    """
    Serve a GET on ``/v1/models``.

    Sends a 200 status, then writes a JSON list with one entry for every
    repository in the Hugging Face cache whose id contains ``mlx``.
    """
    self._set_completion_headers(200)
    self.end_headers()

    # Walk the local cache and keep only the repositories that look like mlx models
    models = []
    for repo in scan_cache_dir().repos:
        if "mlx" not in repo.repo_id:
            continue
        models.append(
            {
                "id": repo.repo_id,
                "object": "model",
                "created": self.created,
            }
        )

    payload = json.dumps({"object": "list", "data": models}).encode()
    self.wfile.write(payload)
    self.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
host: str,
|
host: str,
|
||||||
|
@ -36,7 +36,10 @@ class ChatDataset(Dataset):
|
|||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
messages = self._data[idx]["messages"]
|
messages = self._data[idx]["messages"]
|
||||||
text = self._tokenizer.apply_chat_template(
|
text = self._tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages,
|
||||||
|
tools=self._data[idx].get("tools", None),
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
)
|
)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@ -73,17 +76,14 @@ class CompletionsDataset(Dataset):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
|
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
|
||||||
# Return empty dataset for non-existent paths
|
sample = data[0]
|
||||||
if not path.exists():
|
|
||||||
return []
|
if "messages" in sample:
|
||||||
with open(path, "r") as fid:
|
|
||||||
data = [json.loads(l) for l in fid]
|
|
||||||
if "messages" in data[0]:
|
|
||||||
return ChatDataset(data, tokenizer)
|
return ChatDataset(data, tokenizer)
|
||||||
elif "prompt" in data[0] and "completion" in data[0]:
|
elif "prompt" in sample and "completion" in sample:
|
||||||
return CompletionsDataset(data, tokenizer)
|
return CompletionsDataset(data, tokenizer)
|
||||||
elif "text" in data[0]:
|
elif "text" in sample:
|
||||||
return Dataset(data)
|
return Dataset(data)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -92,54 +92,90 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
|
||||||
if getattr(args, "hf_dataset", None) is not None:
|
def load_subset(path):
|
||||||
import datasets
|
if not path.exists():
|
||||||
|
return []
|
||||||
|
with open(path, "r") as fid:
|
||||||
|
data = [json.loads(l) for l in fid]
|
||||||
|
return create_dataset(data, tokenizer)
|
||||||
|
|
||||||
hf_args = args.hf_dataset
|
names = ("train", "valid", "test")
|
||||||
dataset_name = hf_args["name"]
|
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
|
||||||
print(f"Loading Hugging Face dataset {dataset_name}.")
|
return train, valid, test
|
||||||
text_feature = hf_args.get("text_feature")
|
|
||||||
prompt_feature = hf_args.get("prompt_feature")
|
|
||||||
completion_feature = hf_args.get("completion_feature")
|
|
||||||
|
|
||||||
def create_hf_dataset(split: str = None):
|
|
||||||
ds = datasets.load_dataset(
|
|
||||||
dataset_name,
|
|
||||||
split=split,
|
|
||||||
**hf_args.get("config", {}),
|
|
||||||
)
|
|
||||||
if prompt_feature and completion_feature:
|
|
||||||
return CompletionsDataset(
|
|
||||||
ds, tokenizer, prompt_feature, completion_feature
|
|
||||||
)
|
|
||||||
elif text_feature:
|
|
||||||
return Dataset(train_ds, text_key=text_feature)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Specify either a prompt and completion feature or a text "
|
|
||||||
"feature for the Hugging Face dataset."
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.train:
|
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
|
||||||
train_split = hf_args.get("train_split", "train[:80%]")
|
from datasets import exceptions, load_dataset
|
||||||
valid_split = hf_args.get("valid_split", "train[-10%:]")
|
|
||||||
train = create_hf_dataset(split=train_split)
|
try:
|
||||||
valid = create_hf_dataset(split=valid_split)
|
dataset = load_dataset(data_id)
|
||||||
else:
|
|
||||||
train, valid = [], []
|
|
||||||
if args.test:
|
|
||||||
test = create_hf_dataset(split=hf_args.get("test_split"))
|
|
||||||
else:
|
|
||||||
test = []
|
|
||||||
|
|
||||||
else:
|
|
||||||
names = ("train", "valid", "test")
|
names = ("train", "valid", "test")
|
||||||
data_path = Path(args.data)
|
|
||||||
|
|
||||||
train, valid, test = [
|
train, valid, test = [
|
||||||
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
|
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
|
||||||
|
for n in names
|
||||||
]
|
]
|
||||||
|
|
||||||
|
except exceptions.DatasetNotFoundError:
|
||||||
|
raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
|
||||||
|
|
||||||
|
return train, valid, test
|
||||||
|
|
||||||
|
|
||||||
|
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
    """Load train/valid/test sets from the Hugging Face dataset in ``args.hf_dataset``.

    ``args.hf_dataset`` is a mapping with a required ``name`` and optional
    ``text_feature``, ``prompt_feature``, ``completion_feature``, ``config``,
    ``train_split``, ``valid_split`` and ``test_split`` entries.

    Args:
        args: namespace providing ``hf_dataset``, ``train`` and ``test``.
        tokenizer: tokenizer handed to ``CompletionsDataset``.

    Returns:
        A ``(train, valid, test)`` tuple. ``train`` and ``valid`` are empty
        lists unless ``args.train`` is set; ``test`` is an empty list unless
        ``args.test`` is set.

    Raises:
        ValueError: if neither a prompt/completion feature pair nor a text
            feature is configured.
    """
    import datasets

    hf_args = args.hf_dataset
    dataset_name = hf_args["name"]
    print(f"Loading Hugging Face dataset {dataset_name}.")
    text_feature = hf_args.get("text_feature")
    prompt_feature = hf_args.get("prompt_feature")
    completion_feature = hf_args.get("completion_feature")

    def create_hf_dataset(split: str = None):
        """Load one split and wrap it according to the configured features."""
        ds = datasets.load_dataset(
            dataset_name,
            split=split,
            **hf_args.get("config", {}),
        )
        if prompt_feature and completion_feature:
            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
        elif text_feature:
            # Wrap the split that was just loaded. The previous code referenced
            # an undefined name (`train_ds`) and raised NameError here.
            return Dataset(ds, text_key=text_feature)
        else:
            raise ValueError(
                "Specify either a prompt and completion feature or a text "
                "feature for the Hugging Face dataset."
            )

    if args.train:
        train_split = hf_args.get("train_split", "train[:80%]")
        valid_split = hf_args.get("valid_split", "train[-10%:]")
        train = create_hf_dataset(split=train_split)
        valid = create_hf_dataset(split=valid_split)
    else:
        train, valid = [], []
    if args.test:
        test = create_hf_dataset(split=hf_args.get("test_split"))
    else:
        test = []

    return train, valid, test
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||||
|
if getattr(args, "hf_dataset", None) is not None:
|
||||||
|
train, valid, test = load_custom_hf_dataset(args, tokenizer)
|
||||||
|
else:
|
||||||
|
data_path = Path(args.data)
|
||||||
|
if data_path.exists():
|
||||||
|
train, valid, test = load_local_dataset(data_path, tokenizer)
|
||||||
|
else:
|
||||||
|
print(f"Loading Hugging Face dataset {args.data}.")
|
||||||
|
train, valid, test = load_hf_dataset(args.data, tokenizer)
|
||||||
|
|
||||||
if args.train and len(train) == 0:
|
if args.train and len(train) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Training set not found or empty. Must provide training set for fine-tuning."
|
"Training set not found or empty. Must provide training set for fine-tuning."
|
||||||
|
@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
|
|||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
scale: float = 20.0,
|
scale: float = 20.0,
|
||||||
):
|
):
|
||||||
# TODO support quantized weights in DoRALinear
|
# TODO remove when input_dims and output_dims are attributes
|
||||||
|
# on linear and quantized linear
|
||||||
output_dims, input_dims = linear.weight.shape
|
output_dims, input_dims = linear.weight.shape
|
||||||
if isinstance(linear, nn.QuantizedLinear):
|
if isinstance(linear, nn.QuantizedLinear):
|
||||||
raise ValueError("DoRALinear does not yet support quantization.")
|
input_dims *= 32 // linear.bits
|
||||||
dora_lin = DoRALinear(
|
dora_lin = DoRALinear(
|
||||||
input_dims=input_dims,
|
input_dims=input_dims,
|
||||||
output_dims=output_dims,
|
output_dims=output_dims,
|
||||||
@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
|
|||||||
def fuse(self, de_quantize: bool = False):
|
def fuse(self, de_quantize: bool = False):
|
||||||
linear = self.linear
|
linear = self.linear
|
||||||
bias = "bias" in linear
|
bias = "bias" in linear
|
||||||
weight = linear.weight
|
weight = self._dequantized_weight()
|
||||||
|
|
||||||
# Use the same type as the linear weight if not quantized
|
# Use the same type as the linear weight
|
||||||
dtype = weight.dtype
|
dtype = weight.dtype
|
||||||
|
|
||||||
output_dims, input_dims = weight.shape
|
output_dims, input_dims = weight.shape
|
||||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
|
||||||
|
|
||||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||||
lora_a = self.lora_a.T.astype(dtype)
|
lora_a = self.lora_a.T.astype(dtype)
|
||||||
@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
|
|||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
fused_linear.bias = linear.bias
|
fused_linear.bias = linear.bias
|
||||||
|
|
||||||
|
if self._is_quantized() and not de_quantize:
|
||||||
|
fused_linear = nn.QuantizedLinear.from_linear(
|
||||||
|
fused_linear,
|
||||||
|
linear.group_size,
|
||||||
|
linear.bits,
|
||||||
|
)
|
||||||
return fused_linear
|
return fused_linear
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
|
|||||||
)
|
)
|
||||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||||
|
|
||||||
def set_linear(self, linear: nn.Linear):
|
def set_linear(self, linear):
|
||||||
|
"""
|
||||||
|
Set the self.linear layer and recompute self.m.
|
||||||
|
"""
|
||||||
self.linear = linear
|
self.linear = linear
|
||||||
self.m = mx.linalg.norm(self.linear.weight, axis=1)
|
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
|
||||||
|
|
||||||
|
def _dequantized_weight(self):
|
||||||
|
"""
|
||||||
|
Return the weight of linear layer and dequantize it if is quantized
|
||||||
|
"""
|
||||||
|
weight = self.linear.weight
|
||||||
|
if self._is_quantized():
|
||||||
|
weight = mx.dequantize(
|
||||||
|
weight,
|
||||||
|
self.linear.scales,
|
||||||
|
self.linear.biases,
|
||||||
|
self.linear.group_size,
|
||||||
|
self.linear.bits,
|
||||||
|
)
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def _is_quantized(self):
    """Tell whether the wrapped linear layer is an ``nn.QuantizedLinear``."""
    wrapped = self.linear
    return isinstance(wrapped, nn.QuantizedLinear)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
# Regular LoRA (without a bias)
|
# Regular LoRA (without a bias)
|
||||||
y = x @ self.linear.weight.T
|
w = self._dequantized_weight()
|
||||||
|
y = x @ w.T
|
||||||
|
|
||||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||||
out = y + (self.scale * z).astype(x.dtype)
|
out = y + (self.scale * z).astype(x.dtype)
|
||||||
|
|
||||||
# Compute the norm of the adapted weights
|
# Compute the norm of the adapted weights
|
||||||
adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
|
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
|
||||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
||||||
|
|
||||||
# Remove the norm and scale by the learned magnitude
|
# Remove the norm and scale by the learned magnitude
|
||||||
out = (self.m / denom) * out
|
out = (self.m / denom).astype(x.dtype) * out
|
||||||
|
|
||||||
if "bias" in self.linear:
|
if "bias" in self.linear:
|
||||||
out = out + self.linear.bias
|
out = out + self.linear.bias
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import shutil
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -93,9 +95,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
# Encode batch
|
# Encode batch
|
||||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||||
for b in batch:
|
for b in batch:
|
||||||
if b[-1] == tokenizer.eos_token_id:
|
if b[-1] != tokenizer.eos_token_id:
|
||||||
print("[WARNING] Example already has an EOS token appended")
|
|
||||||
else:
|
|
||||||
b.append(tokenizer.eos_token_id)
|
b.append(tokenizer.eos_token_id)
|
||||||
|
|
||||||
lengths = [len(x) for x in batch]
|
lengths = [len(x) for x in batch]
|
||||||
@ -287,24 +287,18 @@ def train(
|
|||||||
|
|
||||||
# Save adapter weights
|
# Save adapter weights
|
||||||
if it % args.steps_per_save == 0:
|
if it % args.steps_per_save == 0:
|
||||||
save_adapter(model, args.adapter_file)
|
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||||
|
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||||
checkpoint = (
|
checkpoint = (
|
||||||
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
|
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
|
||||||
)
|
)
|
||||||
save_adapter(model, checkpoint)
|
mx.save_safetensors(str(checkpoint), adapter_weights)
|
||||||
print(
|
print(
|
||||||
f"Iter {it}: Saved adapter weights to "
|
f"Iter {it}: Saved adapter weights to "
|
||||||
f"{args.adapter_file} and {checkpoint}."
|
f"{args.adapter_file} and {checkpoint}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# save final adapter weights
|
# Save final weights
|
||||||
save_adapter(model, args.adapter_file)
|
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||||
print(f"Saved final adapter weights to {args.adapter_file}.")
|
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||||
|
print(f"Saved final weights to {args.adapter_file}.")
|
||||||
|
|
||||||
def save_adapter(
|
|
||||||
model: nn.Module,
|
|
||||||
adapter_file: Union[str, Path],
|
|
||||||
):
|
|
||||||
flattened_tree = tree_flatten(model.trainable_parameters())
|
|
||||||
mx.save_safetensors(str(adapter_file), dict(flattened_tree))
|
|
||||||
|
@ -36,7 +36,7 @@ def build_schedule(schedule_config: Dict):
|
|||||||
|
|
||||||
def linear_to_lora_layers(
|
def linear_to_lora_layers(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
num_lora_layers: int,
|
num_layers: int,
|
||||||
config: Dict,
|
config: Dict,
|
||||||
use_dora: bool = False,
|
use_dora: bool = False,
|
||||||
):
|
):
|
||||||
@ -45,23 +45,17 @@ def linear_to_lora_layers(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The neural network model.
|
model (nn.Module): The neural network model.
|
||||||
num_lora_layers (int): The number of blocks to convert to lora layers
|
num_layers (int): The number of blocks to convert to lora layers
|
||||||
starting from the last layer.
|
starting from the last layer.
|
||||||
config (dict): More configuration parameters for LoRA, including the
|
config (dict): More configuration parameters for LoRA, including the
|
||||||
rank, scale, and optional layer keys.
|
rank, scale, and optional layer keys.
|
||||||
use_dora (bool): If True, uses DoRA instead of LoRA.
|
use_dora (bool): If True, uses DoRA instead of LoRA.
|
||||||
Default: ``False``
|
Default: ``False``
|
||||||
"""
|
"""
|
||||||
|
if num_layers > len(model.layers):
|
||||||
num_layers = len(model.layers)
|
|
||||||
|
|
||||||
if num_lora_layers < 0:
|
|
||||||
num_lora_layers = num_layers
|
|
||||||
|
|
||||||
if num_lora_layers > num_layers:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Requested {num_lora_layers} LoRA layers "
|
f"Requested {num_layers} LoRA layers "
|
||||||
f"but the model only has {num_layers} layers."
|
f"but the model only has {len(model.layers)} layers."
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_lora(layer):
|
def to_lora(layer):
|
||||||
@ -93,6 +87,7 @@ def linear_to_lora_layers(
|
|||||||
"llama",
|
"llama",
|
||||||
"phi",
|
"phi",
|
||||||
"mixtral",
|
"mixtral",
|
||||||
|
"nemotron",
|
||||||
"stablelm",
|
"stablelm",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
@ -139,10 +134,19 @@ def linear_to_lora_layers(
|
|||||||
"self_attn.kv_b_proj",
|
"self_attn.kv_b_proj",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
elif model.model_type == "mamba":
|
||||||
|
keys = set(
|
||||||
|
[
|
||||||
|
"mixer.in_proj",
|
||||||
|
"mixer.x_proj",
|
||||||
|
"mixer.dt_proj",
|
||||||
|
"mixer.out_proj",
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Lora does not support {model.model_type}")
|
raise ValueError(f"Lora does not support {model.model_type}")
|
||||||
|
|
||||||
for l in model.layers[num_layers - num_lora_layers :]:
|
for l in model.layers[-min(num_layers, 0) :]:
|
||||||
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
|
||||||
if lora_layers:
|
if lora_layers:
|
||||||
l.update_modules(tree_unflatten(lora_layers))
|
l.update_modules(tree_unflatten(lora_layers))
|
||||||
@ -152,9 +156,9 @@ def linear_to_lora_layers(
|
|||||||
model.update_modules(tree_unflatten(lora_modules))
|
model.update_modules(tree_unflatten(lora_modules))
|
||||||
|
|
||||||
|
|
||||||
def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
|
def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
Apply LoRA layers to the model.
|
Load any fine-tuned adapters / layers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The neural network model.
|
model (nn.Module): The neural network model.
|
||||||
@ -168,12 +172,14 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
|
|||||||
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
|
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
|
||||||
with open(adapter_path / "adapter_config.json", "r") as fid:
|
with open(adapter_path / "adapter_config.json", "r") as fid:
|
||||||
config = types.SimpleNamespace(**json.load(fid))
|
config = types.SimpleNamespace(**json.load(fid))
|
||||||
linear_to_lora_layers(
|
fine_tune_type = getattr(config, "fine_tune_type", "lora")
|
||||||
model,
|
if fine_tune_type != "full":
|
||||||
config.lora_layers,
|
linear_to_lora_layers(
|
||||||
config.lora_parameters,
|
model,
|
||||||
getattr(config, "use_dora", False),
|
config.num_layers,
|
||||||
)
|
config.lora_parameters,
|
||||||
|
use_dora=(fine_tune_type == "dora"),
|
||||||
|
)
|
||||||
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
|
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
@ -22,8 +21,8 @@ from transformers import PreTrainedTokenizer
|
|||||||
from .models.base import KVCache, RotatingKVCache
|
from .models.base import KVCache, RotatingKVCache
|
||||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import apply_lora_layers
|
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
|
from .tuner.utils import load_adapters
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MODEL_REMAPPING = {
|
MODEL_REMAPPING = {
|
||||||
@ -91,7 +90,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except RepositoryNotFoundError:
|
except:
|
||||||
raise ModelNotFoundError(
|
raise ModelNotFoundError(
|
||||||
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
|
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
|
||||||
"Please make sure you specified the local path or Hugging Face"
|
"Please make sure you specified the local path or Hugging Face"
|
||||||
@ -102,7 +101,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
|
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
|
||||||
"""
|
"""
|
||||||
Apply repetition penalty to specific logits based on the given context.
|
Apply repetition penalty to specific logits based on the given context.
|
||||||
|
|
||||||
@ -110,19 +109,18 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (mx.array): The logits produced by the language model.
|
logits (mx.array): The logits produced by the language model.
|
||||||
generated_tokens (any): A list of N previous tokens.
|
tokens (mx.array): A list of N previous tokens.
|
||||||
penalty (float): The repetition penalty factor to be applied.
|
penalty (float): The repetition penalty factor to be applied.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
||||||
"""
|
"""
|
||||||
if len(generated_tokens) > 0:
|
if len(tokens) > 0:
|
||||||
indices = generated_tokens
|
selected_logits = mx.take_along_axis(logits, tokens, axis=-1)
|
||||||
selected_logits = mx.take_along_axis(logits, indices, axis=-1)
|
|
||||||
selected_logits = mx.where(
|
selected_logits = mx.where(
|
||||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
||||||
)
|
)
|
||||||
logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits
|
logits[mx.arange(tokens.shape[0])[:, None], tokens] = selected_logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@ -155,16 +153,17 @@ def generate_step(
|
|||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
min_p: float = 0.0,
|
min_p: float = 0.0,
|
||||||
min_tokens_to_keep: int = 1,
|
min_tokens_to_keep: int = 1,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
|
||||||
prefill_step_size: int = 512,
|
prefill_step_size: int = 512,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
|
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``.
|
prompts (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
Default: ``0``.
|
Default: ``0``.
|
||||||
@ -178,10 +177,13 @@ def generate_step(
|
|||||||
probability) that a token probability must have to be considered.
|
probability) that a token probability must have to be considered.
|
||||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
be filtered by min_p sampling.
|
be filtered by min_p sampling.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
entries (except the first 4 tokens) will be overwritten.
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
|
A list of functions that take tokens and logits and return the processed
|
||||||
|
logits. Default: ``None``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
@ -195,10 +197,6 @@ def generate_step(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
|
def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
|
||||||
if logit_bias:
|
|
||||||
indices = mx.array(list(logit_bias.keys()))
|
|
||||||
values = mx.array(list(logit_bias.values()))
|
|
||||||
logits[:, indices] += values
|
|
||||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||||
|
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
@ -220,7 +218,29 @@ def generate_step(
|
|||||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logits_processor = logits_processor or []
|
||||||
|
|
||||||
|
if repetition_penalty:
|
||||||
|
|
||||||
|
def repetition_penalty_processor(tokens, logits):
|
||||||
|
return apply_repetition_penalty(
|
||||||
|
logits, tokens[-repetition_context_size:], repetition_penalty
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_processor.append(repetition_penalty_processor)
|
||||||
|
|
||||||
|
if logit_bias:
|
||||||
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
|
values = mx.array(list(logit_bias.values()))
|
||||||
|
|
||||||
|
def logit_bias_processor(_, logits):
|
||||||
|
logits[:, indices] += values
|
||||||
|
return logits
|
||||||
|
|
||||||
|
logits_processor.append(logit_bias_processor)
|
||||||
|
|
||||||
y = prompts
|
y = prompts
|
||||||
|
tokens = None
|
||||||
|
|
||||||
# Create the KV cache for generation
|
# Create the KV cache for generation
|
||||||
cache = make_kv_caches(model, max_kv_size)
|
cache = make_kv_caches(model, max_kv_size)
|
||||||
@ -235,28 +255,18 @@ def generate_step(
|
|||||||
c.update_and_fetch(h[0], h[1])
|
c.update_and_fetch(h[0], h[1])
|
||||||
mx.eval([c.state for c in cache])
|
mx.eval([c.state for c in cache])
|
||||||
|
|
||||||
repetition_context = prompts
|
|
||||||
|
|
||||||
if repetition_context_size:
|
|
||||||
repetition_context = repetition_context[:, -repetition_context_size:]
|
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
nonlocal repetition_context
|
|
||||||
logits = model(y, cache=cache)
|
logits = model(y, cache=cache)
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
if repetition_penalty:
|
if logits_processor:
|
||||||
logits = apply_repetition_penalty(
|
nonlocal tokens
|
||||||
logits, repetition_context, repetition_penalty
|
tokens = mx.concat([tokens, y], axis=-1) if tokens is not None else y
|
||||||
)
|
|
||||||
y, logprobs = sample(logits)
|
|
||||||
repetition_context = mx.concatenate([repetition_context, y], axis=-1)
|
|
||||||
else:
|
|
||||||
y, logprobs = sample(logits)
|
|
||||||
|
|
||||||
if repetition_context_size:
|
for processor in logits_processor:
|
||||||
if repetition_context.shape[1] > repetition_context_size:
|
logits = processor(tokens, logits)
|
||||||
repetition_context = repetition_context[:, -repetition_context_size:]
|
|
||||||
|
y, logprobs = sample(logits)
|
||||||
return y, logprobs
|
return y, logprobs
|
||||||
|
|
||||||
while y.shape[1] > prefill_step_size:
|
while y.shape[1] > prefill_step_size:
|
||||||
@ -265,6 +275,7 @@ def generate_step(
|
|||||||
y = y[:, prefill_step_size:]
|
y = y[:, prefill_step_size:]
|
||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y)
|
mx.async_eval(y)
|
||||||
while True:
|
while True:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
@ -280,7 +291,7 @@ def stream_generate(
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> Generator[str, None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing text based on the given prompt from the model.
|
||||||
|
|
||||||
@ -320,19 +331,19 @@ def stream_generate(
|
|||||||
def generate(
|
def generate(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: Union[str, List[str]],
|
prompt: str,
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, List[str]]:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a complete response from the model.
|
Generate a complete response from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The language model.
|
model (nn.Module): The language model.
|
||||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
prompts (str): The string prompt(s).
|
prompt (str): The string prompt.
|
||||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
verbose (bool): If ``True``, print tokens and timing information.
|
verbose (bool): If ``True``, print tokens and timing information.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
@ -341,30 +352,98 @@ def generate(
|
|||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||||
See :func:`generate_step` for more details.
|
See :func:`generate_step` for more details.
|
||||||
"""
|
"""
|
||||||
is_batch = isinstance(prompt, list)
|
|
||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
if is_batch:
|
if verbose:
|
||||||
tokenizer._tokenizer.padding_side = "left"
|
print("=" * 10)
|
||||||
if tokenizer.pad_token is None:
|
print("Prompt:", prompt)
|
||||||
tokenizer._tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
|
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||||
prompt_tokens = mx.array(
|
detokenizer = tokenizer.detokenizer
|
||||||
tokenizer._tokenizer(prompt, padding=True)["input_ids"]
|
|
||||||
)
|
tic = time.perf_counter()
|
||||||
output_toks = []
|
detokenizer.reset()
|
||||||
else:
|
|
||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))[None]
|
for (token, logprobs), n in zip(
|
||||||
detokenizer = tokenizer.detokenizer
|
generate_step(prompt_tokens[None], model, **kwargs),
|
||||||
detokenizer.reset()
|
range(max_tokens),
|
||||||
|
):
|
||||||
|
token = token.item()
|
||||||
|
if n == 0:
|
||||||
|
prompt_time = time.perf_counter() - tic
|
||||||
|
tic = time.perf_counter()
|
||||||
|
if token == tokenizer.eos_token_id:
|
||||||
|
break
|
||||||
|
detokenizer.add_token(token)
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print("=" * 10)
|
if formatter:
|
||||||
print("Prompt:", prompt)
|
# We have to finalize so that the prob corresponds to the last segment
|
||||||
|
detokenizer.finalize()
|
||||||
|
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
||||||
|
else:
|
||||||
|
print(detokenizer.last_segment, end="", flush=True)
|
||||||
|
|
||||||
|
token_count = n + 1
|
||||||
|
detokenizer.finalize()
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
gen_time = time.perf_counter() - tic
|
||||||
|
print(detokenizer.last_segment, flush=True)
|
||||||
|
print("=" * 10)
|
||||||
|
if token_count == 0:
|
||||||
|
print("No tokens generated for this prompt")
|
||||||
|
return
|
||||||
|
prompt_tps = prompt_tokens.size / prompt_time
|
||||||
|
gen_tps = (token_count - 1) / gen_time
|
||||||
|
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
||||||
|
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||||
|
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||||
|
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
|
||||||
|
return detokenizer.text
|
||||||
|
|
||||||
|
|
||||||
|
def batch_generate(
|
||||||
|
model: nn.Module,
|
||||||
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: int = 100,
|
||||||
|
verbose: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a complete response from the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): The language model.
|
||||||
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
|
prompts (List[str]): The string prompts.
|
||||||
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
|
verbose (bool): If ``True``, print tokens and timing information.
|
||||||
|
Default: ``False``.
|
||||||
|
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||||
|
See :func:`generate_step` for more details.
|
||||||
|
"""
|
||||||
|
if kwargs.get("max_kv_size", None) is not None:
|
||||||
|
raise ValueError("max_kv_size is not supported for batch generation")
|
||||||
|
|
||||||
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
|
tokenizer._tokenizer.padding_side = "left"
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer._tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
|
prompt_tokens = mx.array(
|
||||||
|
tokenizer._tokenizer(prompts, padding=True)["input_ids"]
|
||||||
|
)
|
||||||
|
output_toks = []
|
||||||
|
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
|
|
||||||
for (tokens, logprobs), n in zip(
|
for (tokens, _), n in zip(
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
):
|
):
|
||||||
@ -373,51 +452,34 @@ def generate(
|
|||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
if (tokens == tokenizer.eos_token_id).all():
|
if (tokens == tokenizer.eos_token_id).all():
|
||||||
break
|
break
|
||||||
if is_batch:
|
output_toks.append(tokens)
|
||||||
output_toks.append(tokens)
|
if verbose:
|
||||||
if verbose:
|
print(".", end="", flush=True)
|
||||||
print(".", end="", flush=True)
|
|
||||||
else:
|
|
||||||
token = tokens.item()
|
|
||||||
logprobs = logprobs.squeeze(0)
|
|
||||||
detokenizer.add_token(token)
|
|
||||||
if verbose:
|
|
||||||
if formatter:
|
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
|
||||||
detokenizer.finalize()
|
|
||||||
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
|
||||||
else:
|
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
|
||||||
|
|
||||||
if is_batch:
|
output_toks = mx.concatenate(output_toks, axis=1)
|
||||||
output_toks = mx.concatenate(output_toks, axis=1)
|
token_count = output_toks.size
|
||||||
token_count = output_toks.size
|
response = [
|
||||||
response = [
|
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
|
||||||
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
|
for response in tokenizer.batch_decode(output_toks.tolist())
|
||||||
for response in tokenizer.batch_decode(output_toks.tolist())
|
]
|
||||||
]
|
|
||||||
else:
|
|
||||||
token_count = n
|
|
||||||
detokenizer.finalize()
|
|
||||||
response = detokenizer.text
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
gen_time = time.perf_counter() - tic
|
gen_time = time.perf_counter() - tic
|
||||||
if token_count <= 0:
|
if token_count <= 0:
|
||||||
print("No tokens generated for this prompt")
|
print("No tokens generated for this prompt")
|
||||||
if is_batch:
|
else:
|
||||||
print()
|
print()
|
||||||
for p, resp in zip(prompt, response):
|
for p, resp in zip(prompts, response):
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print("Prompt:", p)
|
print("Prompt:", p)
|
||||||
print(resp)
|
print(resp)
|
||||||
else:
|
|
||||||
print(detokenizer.last_segment, flush=True)
|
|
||||||
prompt_tps = prompt_tokens.size / prompt_time
|
prompt_tps = prompt_tokens.size / prompt_time
|
||||||
gen_tps = token_count / gen_time
|
gen_tps = token_count / gen_time
|
||||||
print("=" * 10)
|
print("=" * 10)
|
||||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
||||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||||
|
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||||
|
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -539,7 +601,7 @@ def load(
|
|||||||
|
|
||||||
model = load_model(model_path, lazy, model_config)
|
model = load_model(model_path, lazy, model_config)
|
||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = apply_lora_layers(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
tokenizer = load_tokenizer(model_path, tokenizer_config)
|
tokenizer = load_tokenizer(model_path, tokenizer_config)
|
||||||
|
|
||||||
@ -596,6 +658,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
|||||||
|
|
||||||
card = ModelCard.load(hf_path)
|
card = ModelCard.load(hf_path)
|
||||||
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
|
||||||
|
card.data.base_model = hf_path
|
||||||
card.text = dedent(
|
card.text = dedent(
|
||||||
f"""
|
f"""
|
||||||
# {upload_repo}
|
# {upload_repo}
|
||||||
@ -612,7 +675,16 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
|||||||
from mlx_lm import load, generate
|
from mlx_lm import load, generate
|
||||||
|
|
||||||
model, tokenizer = load("{upload_repo}")
|
model, tokenizer = load("{upload_repo}")
|
||||||
response = generate(model, tokenizer, prompt="hello", verbose=True)
|
|
||||||
|
prompt="hello"
|
||||||
|
|
||||||
|
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
|
||||||
|
messages = [{{"role": "user", "content": prompt}}]
|
||||||
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response = generate(model, tokenizer, prompt=prompt, verbose=True)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
@ -702,6 +774,8 @@ def quantize_model(
|
|||||||
quantized_config = copy.deepcopy(config)
|
quantized_config = copy.deepcopy(config)
|
||||||
nn.quantize(model, q_group_size, q_bits)
|
nn.quantize(model, q_group_size, q_bits)
|
||||||
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
|
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
|
||||||
|
# support hf model tree #957
|
||||||
|
quantized_config["quantization_config"] = quantized_config["quantization"]
|
||||||
quantized_weights = dict(tree_flatten(model.parameters()))
|
quantized_weights = dict(tree_flatten(model.parameters()))
|
||||||
|
|
||||||
return quantized_weights, quantized_config
|
return quantized_weights, quantized_config
|
||||||
|
@ -10,7 +10,7 @@ with open(package_dir / "requirements.txt") as fid:
|
|||||||
requirements = [l.strip() for l in fid.readlines()]
|
requirements = [l.strip() for l in fid.readlines()]
|
||||||
|
|
||||||
sys.path.append(str(package_dir))
|
sys.path.append(str(package_dir))
|
||||||
from version import __version__
|
from _version import __version__
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx-lm",
|
name="mlx-lm",
|
||||||
|
@ -11,7 +11,7 @@ import mlx.nn as nn
|
|||||||
import mlx.optimizers as opt
|
import mlx.optimizers as opt
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from mlx_lm import lora, tuner
|
from mlx_lm import lora, tuner
|
||||||
from mlx_lm.tuner.dora import DoRAEmbedding
|
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
|
||||||
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
|
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
|
||||||
from mlx_lm.tuner.trainer import evaluate
|
from mlx_lm.tuner.trainer import evaluate
|
||||||
from mlx_lm.tuner.utils import build_schedule
|
from mlx_lm.tuner.utils import build_schedule
|
||||||
@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
|
|||||||
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
|
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
|
||||||
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
|
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
|
||||||
|
|
||||||
|
def test_llama(self):
|
||||||
|
from mlx_lm.models import llama
|
||||||
|
|
||||||
|
hidden_size = 1024
|
||||||
|
intermediate_size = 2048
|
||||||
|
args = llama.ModelArgs(
|
||||||
|
model_type="llama",
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
num_attention_heads=4,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
vocab_size=10_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
dora_layers = 4
|
||||||
|
|
||||||
|
def check_config(params):
|
||||||
|
n_keys = 2
|
||||||
|
if "keys" in params:
|
||||||
|
n_keys = len(params["keys"])
|
||||||
|
model = llama.Model(args)
|
||||||
|
model.freeze()
|
||||||
|
tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
|
||||||
|
trainable_params = sum(
|
||||||
|
v.size for _, v in tree_flatten(model.trainable_parameters())
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
trainable_params,
|
||||||
|
dora_layers
|
||||||
|
* (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
|
||||||
|
check_config(params)
|
||||||
|
|
||||||
|
params["rank"] = 1
|
||||||
|
check_config(params)
|
||||||
|
|
||||||
|
params["keys"] = ["self_attn.k_proj"]
|
||||||
|
check_config(params)
|
||||||
|
|
||||||
|
def test_dora_m_parameter(self):
|
||||||
|
dora_lin = DoRALinear(input_dims=100, output_dims=100)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recomputes m when changing Linear
|
||||||
|
inital_m = dora_lin.m
|
||||||
|
lin = nn.Linear(10, 10)
|
||||||
|
dora_lin.set_linear(lin)
|
||||||
|
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
|
||||||
|
|
||||||
|
# Works with quantized weights
|
||||||
|
quantized_linear = nn.QuantizedLinear(512, 512)
|
||||||
|
dora_lin.set_linear(quantized_linear)
|
||||||
|
dequantized_weight = mx.dequantize(
|
||||||
|
quantized_linear.weight,
|
||||||
|
quantized_linear.scales,
|
||||||
|
quantized_linear.biases,
|
||||||
|
quantized_linear.group_size,
|
||||||
|
quantized_linear.bits,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dora_from_linear(self):
|
||||||
|
in_dims = 256
|
||||||
|
out_dims = 256
|
||||||
|
r = 4
|
||||||
|
|
||||||
|
linear = nn.Linear(in_dims, out_dims)
|
||||||
|
dora_lin = DoRALinear.from_base(linear, r)
|
||||||
|
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
|
||||||
|
self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
|
||||||
|
self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
|
||||||
|
self.assertEqual(dora_lin.m.shape, (out_dims,))
|
||||||
|
|
||||||
|
quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
|
||||||
|
dequantized_weight = mx.dequantize(
|
||||||
|
quantized_linear.weight,
|
||||||
|
quantized_linear.scales,
|
||||||
|
quantized_linear.biases,
|
||||||
|
quantized_linear.group_size,
|
||||||
|
quantized_linear.bits,
|
||||||
|
)
|
||||||
|
dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
|
||||||
|
)
|
||||||
|
self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
|
||||||
|
self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
|
||||||
|
self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
|
||||||
|
|
||||||
|
def test_dora_to_linear(self):
|
||||||
|
in_dims = 256
|
||||||
|
out_dims = 256
|
||||||
|
r = 4
|
||||||
|
|
||||||
|
linear = nn.Linear(in_dims, out_dims, bias=True)
|
||||||
|
dora_lin = DoRALinear.from_base(linear, r)
|
||||||
|
to_linear = dora_lin.fuse()
|
||||||
|
self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
|
||||||
|
self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
|
||||||
|
|
||||||
|
def dequantize_weight(quantized_linear):
|
||||||
|
return mx.dequantize(
|
||||||
|
quantized_linear.weight,
|
||||||
|
quantized_linear.scales,
|
||||||
|
quantized_linear.biases,
|
||||||
|
quantized_linear.group_size,
|
||||||
|
quantized_linear.bits,
|
||||||
|
)
|
||||||
|
|
||||||
|
quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
|
||||||
|
dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
|
||||||
|
# Dequantize
|
||||||
|
to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
dequantize_weight(quantized_linear), to_linear_from_quantized.weight
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dora_dtype(self):
|
||||||
|
in_dims = 256
|
||||||
|
out_dims = 256
|
||||||
|
r = 4
|
||||||
|
|
||||||
|
linear = nn.Linear(in_dims, out_dims, bias=True)
|
||||||
|
linear.set_dtype(mx.float16)
|
||||||
|
dora_lin = DoRALinear.from_base(linear, r)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
|
||||||
|
self.assertEqual(dora_lin(x).dtype, mx.float16)
|
||||||
|
|
||||||
|
|
||||||
class TestScheduleConfig(unittest.TestCase):
|
class TestScheduleConfig(unittest.TestCase):
|
||||||
def test_join(self):
|
def test_join(self):
|
||||||
|
55
llms/tests/test_generate.py
Normal file
55
llms/tests/test_generate.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from mlx_lm.utils import generate, load
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerate(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
|
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||||
|
|
||||||
|
def test_generate(self):
|
||||||
|
# Simple test that generation runs
|
||||||
|
text = generate(
|
||||||
|
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_generate_with_logit_bias(self):
|
||||||
|
logit_bias = {0: 2000.0, 1: -20.0}
|
||||||
|
text = generate(
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
"hello",
|
||||||
|
max_tokens=5,
|
||||||
|
verbose=False,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
)
|
||||||
|
self.assertEqual(text, "!!!!!")
|
||||||
|
|
||||||
|
def test_generate_with_processor(self):
|
||||||
|
init_toks = self.tokenizer.encode("hello")
|
||||||
|
|
||||||
|
all_toks = None
|
||||||
|
|
||||||
|
def logits_processor(toks, logits):
|
||||||
|
nonlocal all_toks
|
||||||
|
all_toks = toks
|
||||||
|
return logits
|
||||||
|
|
||||||
|
generate(
|
||||||
|
self.model,
|
||||||
|
self.tokenizer,
|
||||||
|
"hello",
|
||||||
|
max_tokens=5,
|
||||||
|
verbose=False,
|
||||||
|
logits_processor=[logits_processor],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -5,6 +5,7 @@ import unittest
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||||
|
from mlx_lm.utils import make_kv_caches
|
||||||
|
|
||||||
|
|
||||||
class TestModels(unittest.TestCase):
|
class TestModels(unittest.TestCase):
|
||||||
@ -100,13 +101,7 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
kv_heads = (
|
cache = make_kv_caches(model)
|
||||||
[model.n_kv_heads] * len(model.layers)
|
|
||||||
if isinstance(model.n_kv_heads, int)
|
|
||||||
else model.n_kv_heads
|
|
||||||
)
|
|
||||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
|
||||||
|
|
||||||
outputs = model(inputs, cache)
|
outputs = model(inputs, cache)
|
||||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
@ -397,6 +392,26 @@ class TestModels(unittest.TestCase):
|
|||||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_mamba(self):
|
||||||
|
from mlx_lm.models import mamba
|
||||||
|
|
||||||
|
args = mamba.ModelArgs(
|
||||||
|
model_type="mamba",
|
||||||
|
vocab_size=10000,
|
||||||
|
use_bias=False,
|
||||||
|
use_conv_bias=True,
|
||||||
|
conv_kernel=4,
|
||||||
|
hidden_size=768,
|
||||||
|
num_hidden_layers=24,
|
||||||
|
state_size=16,
|
||||||
|
intermediate_size=1536,
|
||||||
|
time_step_rank=48,
|
||||||
|
)
|
||||||
|
model = mamba.Model(args)
|
||||||
|
self.model_test_runner(
|
||||||
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
|
)
|
||||||
|
|
||||||
def test_gpt2(self):
|
def test_gpt2(self):
|
||||||
from mlx_lm.models import gpt2
|
from mlx_lm.models import gpt2
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import http
|
import http
|
||||||
|
import json
|
||||||
import threading
|
import threading
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -77,6 +79,19 @@ class TestServer(unittest.TestCase):
|
|||||||
self.assertIn("id", response_body)
|
self.assertIn("id", response_body)
|
||||||
self.assertIn("choices", response_body)
|
self.assertIn("choices", response_body)
|
||||||
|
|
||||||
|
def test_handle_models(self):
|
||||||
|
url = f"http://localhost:{self.port}/v1/models"
|
||||||
|
response = requests.get(url)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
response_body = json.loads(response.text)
|
||||||
|
self.assertEqual(response_body["object"], "list")
|
||||||
|
self.assertIsInstance(response_body["data"], list)
|
||||||
|
self.assertGreater(len(response_body["data"]), 0)
|
||||||
|
model = response_body["data"][0]
|
||||||
|
self.assertIn("id", model)
|
||||||
|
self.assertEqual(model["object"], "model")
|
||||||
|
self.assertIn("created", model)
|
||||||
|
|
||||||
def test_sequence_overlap(self):
|
def test_sequence_overlap(self):
|
||||||
from mlx_lm.server import sequence_overlap
|
from mlx_lm.server import sequence_overlap
|
||||||
|
|
||||||
|
@ -35,6 +35,8 @@ _MODELS = {
|
|||||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||||
|
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||||
|
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||||
@ -52,6 +54,8 @@ _ALIGNMENT_HEADS = {
|
|||||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||||
|
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||||
|
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from . import audio, decoding, load_models
|
from . import audio, decoding, load_models
|
||||||
|
from ._version import __version__
|
||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
from .version import __version__
|
|
||||||
|
@ -12,7 +12,7 @@ with open(package_dir / "requirements.txt") as fid:
|
|||||||
|
|
||||||
sys.path.append(str(package_dir))
|
sys.path.append(str(package_dir))
|
||||||
|
|
||||||
from version import __version__
|
from _version import __version__
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx-whisper",
|
name="mlx-whisper",
|
||||||
|
Loading…
Reference in New Issue
Block a user