mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Merge branch 'ml-explore:main' into main
This commit is contained in:
@@ -17,30 +17,6 @@ jobs:
|
||||
pre-commit run --all
|
||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||
|
||||
mlx_lm_build_and_test:
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.large.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.8
|
||||
python3.8 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install unittest-xml-reporting
|
||||
cd llms/
|
||||
pip install -e ".[testing]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python -m xmlrunner discover -v llms/tests -o test-results/
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
@@ -48,7 +24,6 @@ workflows:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- mlx_lm_build_and_test
|
||||
- linux_build_and_test
|
||||
|
||||
prb:
|
||||
@@ -61,7 +36,5 @@ workflows:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- mlx_lm_build_and_test:
|
||||
requires: [ hold ]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -6,6 +6,9 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Vim
|
||||
*.swp
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.3.0
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
|
||||
@@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
|
||||
- Markus Enzweiler: Added the `cvae` examples.
|
||||
- Prince Canuma: Helped add support for `Starcoder2` models.
|
||||
- Shiyu Li: Added the `Segment Anything Model`.
|
||||
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`.
|
||||
15
README.md
15
README.md
@@ -4,12 +4,12 @@ This repo contains a variety of standalone examples using the [MLX
|
||||
framework](https://github.com/ml-explore/mlx).
|
||||
|
||||
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
|
||||
|
||||
Some more useful examples are listed below.
|
||||
Some more useful examples are listed below. Check-out [MLX
|
||||
LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python
|
||||
package for LLMs with MLX.
|
||||
|
||||
### Text Models
|
||||
|
||||
- [MLX LM](llms/README.md) a package for LLM text generation, fine-tuning, and more.
|
||||
- [Transformer language model](transformer_lm) training.
|
||||
- Minimal examples of large scale text generation with [LLaMA](llms/llama),
|
||||
[Mistral](llms/mistral), and more in the [LLMs](llms) directory.
|
||||
@@ -20,18 +20,23 @@ Some more useful examples are listed below.
|
||||
|
||||
### Image Models
|
||||
|
||||
- Generating images
|
||||
- [FLUX](flux)
|
||||
- [Stable Diffusion or SDXL](stable_diffusion)
|
||||
- Image classification using [ResNets on CIFAR-10](cifar).
|
||||
- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
|
||||
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
|
||||
|
||||
### Audio Models
|
||||
|
||||
- Speech recognition with [OpenAI's Whisper](whisper).
|
||||
- Audio compression and generation with [Meta's EnCodec](encodec).
|
||||
- Music generation with [Meta's MusicGen](musicgen).
|
||||
|
||||
### Multimodal models
|
||||
|
||||
- Joint text and image embeddings with [CLIP](clip).
|
||||
- Text generation from image and text inputs with [LLaVA](llava).
|
||||
- Image segmentation with [Segment Anything (SAM)](segment_anything).
|
||||
|
||||
### Other Models
|
||||
|
||||
@@ -41,7 +46,7 @@ Some more useful examples are listed below.
|
||||
|
||||
### Hugging Face
|
||||
|
||||
Note: You can now directly download a few converted checkpoints from the [MLX
|
||||
You can directly use or download converted checkpoints from the [MLX
|
||||
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
|
||||
We encourage you to join the community and [contribute new
|
||||
models](https://github.com/ml-explore/mlx-examples/issues/155).
|
||||
|
||||
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
|
||||
|
||||
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
|
||||
We intend to update this example once these features are added.
|
||||
|
||||
## Distributed training
|
||||
|
||||
The example also supports distributed data parallel training. You can launch a
|
||||
distributed training as follows:
|
||||
|
||||
```shell
|
||||
$ cat >hostfile.json
|
||||
[
|
||||
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
|
||||
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
|
||||
]
|
||||
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
|
||||
```
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.data.datasets import load_cifar10
|
||||
|
||||
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
|
||||
x = x.astype("float32") / 255.0
|
||||
return (x - mean) / std
|
||||
|
||||
group = mx.distributed.init()
|
||||
|
||||
tr_iter = (
|
||||
tr.shuffle()
|
||||
.partition_if(group.size() > 1, group.size(), group.rank())
|
||||
.to_stream()
|
||||
.image_random_h_flip("image", prob=0.5)
|
||||
.pad("image", 0, 4, 4, 0.0)
|
||||
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
|
||||
)
|
||||
|
||||
test = load_cifar10(root=root, train=False)
|
||||
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
|
||||
test_iter = (
|
||||
test.to_stream()
|
||||
.partition_if(group.size() > 1, group.size(), group.rank())
|
||||
.key_transform("image", normalize)
|
||||
.batch(batch_size)
|
||||
)
|
||||
|
||||
return tr_iter, test_iter
|
||||
|
||||
@@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||
parser.add_argument("--cpu", action="store_true", help="use cpu only")
|
||||
|
||||
|
||||
def print_zero(group, *args, **kwargs):
|
||||
if group.rank() != 0:
|
||||
return
|
||||
flush = kwargs.pop("flush", True)
|
||||
print(*args, **kwargs, flush=flush)
|
||||
|
||||
|
||||
def eval_fn(model, inp, tgt):
|
||||
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
|
||||
|
||||
@@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
|
||||
return loss, acc
|
||||
|
||||
losses = []
|
||||
accs = []
|
||||
samples_per_sec = []
|
||||
world = mx.distributed.init()
|
||||
losses = 0
|
||||
accuracies = 0
|
||||
samples_per_sec = 0
|
||||
count = 0
|
||||
|
||||
def average_stats(stats, count):
|
||||
if world.size() == 1:
|
||||
return [s / count for s in stats]
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
stats = mx.distributed.all_sum(mx.array(stats))
|
||||
count = mx.distributed.all_sum(count)
|
||||
return (stats / count).tolist()
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
def step(inp, tgt):
|
||||
train_step_fn = nn.value_and_grad(model, train_step)
|
||||
(loss, acc), grads = train_step_fn(model, inp, tgt)
|
||||
grads = nn.utils.average_gradients(grads)
|
||||
optimizer.update(model, grads)
|
||||
return loss, acc
|
||||
|
||||
@@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
y = mx.array(batch["label"])
|
||||
tic = time.perf_counter()
|
||||
loss, acc = step(x, y)
|
||||
mx.eval(state)
|
||||
mx.eval(loss, acc, state)
|
||||
toc = time.perf_counter()
|
||||
loss = loss.item()
|
||||
acc = acc.item()
|
||||
losses.append(loss)
|
||||
accs.append(acc)
|
||||
throughput = x.shape[0] / (toc - tic)
|
||||
samples_per_sec.append(throughput)
|
||||
losses += loss.item()
|
||||
accuracies += acc.item()
|
||||
samples_per_sec += x.shape[0] / (toc - tic)
|
||||
count += 1
|
||||
if batch_counter % 10 == 0:
|
||||
print(
|
||||
l, a, s = average_stats(
|
||||
[losses, accuracies, world.size() * samples_per_sec],
|
||||
count,
|
||||
)
|
||||
print_zero(
|
||||
world,
|
||||
" | ".join(
|
||||
(
|
||||
f"Epoch {epoch:02d} [{batch_counter:03d}]",
|
||||
f"Train loss {loss:.3f}",
|
||||
f"Train acc {acc:.3f}",
|
||||
f"Throughput: {throughput:.2f} images/second",
|
||||
f"Train loss {l:.3f}",
|
||||
f"Train acc {a:.3f}",
|
||||
f"Throughput: {s:.2f} images/second",
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
mean_tr_loss = mx.mean(mx.array(losses))
|
||||
mean_tr_acc = mx.mean(mx.array(accs))
|
||||
samples_per_sec = mx.mean(mx.array(samples_per_sec))
|
||||
return mean_tr_loss, mean_tr_acc, samples_per_sec
|
||||
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
|
||||
|
||||
|
||||
def test_epoch(model, test_iter, epoch):
|
||||
accs = []
|
||||
accuracies = 0
|
||||
count = 0
|
||||
for batch_counter, batch in enumerate(test_iter):
|
||||
x = mx.array(batch["image"])
|
||||
y = mx.array(batch["label"])
|
||||
acc = eval_fn(model, x, y)
|
||||
acc_value = acc.item()
|
||||
accs.append(acc_value)
|
||||
mean_acc = mx.mean(mx.array(accs))
|
||||
return mean_acc
|
||||
accuracies += acc.item()
|
||||
count += 1
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
accuracies = mx.distributed.all_sum(accuracies)
|
||||
count = mx.distributed.all_sum(count)
|
||||
return (accuracies / count).item()
|
||||
|
||||
|
||||
def main(args):
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
# Initialize the distributed group and report the nodes that showed up
|
||||
world = mx.distributed.init()
|
||||
if world.size() > 1:
|
||||
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
|
||||
|
||||
model = getattr(resnet, args.arch)()
|
||||
|
||||
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
|
||||
print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
|
||||
|
||||
optimizer = optim.Adam(learning_rate=args.lr)
|
||||
|
||||
train_data, test_data = get_cifar10(args.batch_size)
|
||||
for epoch in range(args.epochs):
|
||||
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
|
||||
print(
|
||||
print_zero(
|
||||
world,
|
||||
" | ".join(
|
||||
(
|
||||
f"Epoch: {epoch}",
|
||||
f"avg. Train loss {tr_loss.item():.3f}",
|
||||
f"avg. Train acc {tr_acc.item():.3f}",
|
||||
f"Throughput: {throughput.item():.2f} images/sec",
|
||||
f"avg. Train loss {tr_loss:.3f}",
|
||||
f"avg. Train acc {tr_acc:.3f}",
|
||||
f"Throughput: {throughput:.2f} images/sec",
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
test_acc = test_epoch(model, test_data, epoch)
|
||||
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
|
||||
print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
|
||||
|
||||
train_data.reset()
|
||||
test_data.reset()
|
||||
|
||||
@@ -121,7 +121,7 @@ if __name__ == "__main__":
|
||||
mlx_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("[INFO] Loading")
|
||||
torch_weights = torch.load(torch_path / "pytorch_model.bin")
|
||||
torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
|
||||
print("[INFO] Converting")
|
||||
mlx_weights = {
|
||||
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
|
||||
|
||||
56
clip/linear_probe.py
Normal file
56
clip/linear_probe.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Mirror of the Linear Probe Evaluation Script
|
||||
# from the official CLIP Repository.
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from image_processor import CLIPImageProcessor
|
||||
from mlx.data.datasets import load_cifar10
|
||||
from model import CLIPModel
|
||||
from PIL import Image
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_cifar10(batch_size, root=None):
|
||||
tr = load_cifar10(root=root).batch(batch_size)
|
||||
test = load_cifar10(root=root, train=False).batch(batch_size)
|
||||
|
||||
return tr, test
|
||||
|
||||
|
||||
def get_features(model, image_proc, iter):
|
||||
all_features = []
|
||||
all_labels = []
|
||||
|
||||
for batch in tqdm(iter):
|
||||
image, label = batch["image"], batch["label"]
|
||||
x = image_proc([Image.fromarray(im) for im in image])
|
||||
y = mx.array(label)
|
||||
|
||||
image_embeds = model.get_image_features(x)
|
||||
mx.eval(image_embeds)
|
||||
|
||||
all_features.append(image_embeds)
|
||||
all_labels.append(y)
|
||||
|
||||
return mx.concatenate(all_features), mx.concatenate(all_labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = CLIPModel.from_pretrained("mlx_model")
|
||||
image_proc = CLIPImageProcessor.from_pretrained("mlx_model")
|
||||
|
||||
train_iter, test_iter = get_cifar10(batch_size=256)
|
||||
train_features, train_labels = get_features(model, image_proc, train_iter)
|
||||
test_features, test_labels = get_features(model, image_proc, test_iter)
|
||||
|
||||
# Perform logistic regression
|
||||
# NOTE: The value of C should be determined via a hyperparameter sweep
|
||||
# using a validation split
|
||||
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
|
||||
classifier.fit(train_features, train_labels)
|
||||
|
||||
# Evaluate using the logistic regression classifier
|
||||
predictions = classifier.predict(test_features)
|
||||
accuracy = (test_labels.squeeze() == predictions).mean().item() * 100
|
||||
print(f"Accuracy = {accuracy:.3f}")
|
||||
@@ -1,4 +1,5 @@
|
||||
mlx
|
||||
mlx-data
|
||||
numpy
|
||||
transformers
|
||||
torch
|
||||
|
||||
84
encodec/README.md
Normal file
84
encodec/README.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# EnCodec
|
||||
|
||||
An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
|
||||
generate audio.
|
||||
|
||||
### Setup
|
||||
|
||||
Install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Optionally install FFmpeg and SciPy for loading and saving audio files,
|
||||
respectively.
|
||||
|
||||
Install [FFmpeg](https://ffmpeg.org/):
|
||||
|
||||
```
|
||||
# on macOS using Homebrew (https://brew.sh/)
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
Install SciPy:
|
||||
|
||||
```
|
||||
pip install scipy
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
An example using the model:
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from encodec import EncodecModel
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
||||
|
||||
# Preprocess the audio (this can also be a list of arrays for batched
|
||||
# processing).
|
||||
feats, mask = processor(audio)
|
||||
|
||||
# Encode at the given bandwidth. A lower bandwidth results in more
|
||||
# compression but lower reconstruction quality.
|
||||
@mx.compile
|
||||
def encode(feats, mask):
|
||||
return model.encode(feats, mask, bandwidth=3)
|
||||
|
||||
# Decode to reconstruct the audio
|
||||
@mx.compile
|
||||
def decode(codes, scales, mask):
|
||||
return model.decode(codes, scales, mask)
|
||||
|
||||
|
||||
codes, scales = encode(feats, mask)
|
||||
reconstructed = decode(codes, scales, mask)
|
||||
|
||||
# Trim any padding:
|
||||
reconstructed = reconstructed[0, : len(audio)]
|
||||
|
||||
# Save the audio as a wave file
|
||||
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
|
||||
```
|
||||
|
||||
The 24 KHz, 32 KHz, and 48 KHz MLX formatted models are available in the
|
||||
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
|
||||
in several data types.
|
||||
|
||||
### Optional
|
||||
|
||||
To convert models, use the `convert.py` script. To see the options, run:
|
||||
|
||||
```bash
|
||||
python convert.py -h
|
||||
```
|
||||
|
||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
|
||||
[code](https://github.com/facebookresearch/encodec) for more details.
|
||||
31
encodec/benchmarks/bench_mx.py
Normal file
31
encodec/benchmarks/bench_mx.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from encodec import EncodecModel
|
||||
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
audio = mx.random.uniform(shape=(288000, 2))
|
||||
feats, mask = processor(audio)
|
||||
mx.eval(model, feats, mask)
|
||||
|
||||
|
||||
@mx.compile
|
||||
def fun():
|
||||
codes, scales = model.encode(feats, mask, bandwidth=3)
|
||||
reconstructed = model.decode(codes, scales, mask)
|
||||
return reconstructed
|
||||
|
||||
|
||||
for _ in range(5):
|
||||
mx.eval(fun())
|
||||
|
||||
tic = time.time()
|
||||
for _ in range(10):
|
||||
mx.eval(fun())
|
||||
toc = time.time()
|
||||
ms = 1000 * (toc - tic) / 10
|
||||
print(f"Time per it: {ms:.3f}")
|
||||
34
encodec/benchmarks/bench_pt.py
Normal file
34
encodec/benchmarks/bench_pt.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoProcessor, EncodecModel
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
|
||||
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)
|
||||
|
||||
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
|
||||
pt_inputs = processor(
|
||||
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
|
||||
).to("mps")
|
||||
|
||||
|
||||
def fun():
|
||||
pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
|
||||
pt_audio = pt_model.decode(
|
||||
pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
|
||||
)
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
for _ in range(5):
|
||||
fun()
|
||||
|
||||
tic = time.time()
|
||||
for _ in range(10):
|
||||
fun()
|
||||
toc = time.time()
|
||||
ms = 1000 * (toc - tic) / 10
|
||||
print(f"Time per it: {ms:.3f}")
|
||||
212
encodec/convert.py
Normal file
212
encodec/convert.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import encodec
|
||||
|
||||
|
||||
def fetch_from_hub(hf_repo: str) -> Path:
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=hf_repo,
|
||||
allow_patterns=["*.json", "*.safetensors"],
|
||||
)
|
||||
)
|
||||
return model_path
|
||||
|
||||
|
||||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
||||
"""
|
||||
Uploads the model to Hugging Face hub.
|
||||
|
||||
Args:
|
||||
path (str): Local path to the model.
|
||||
upload_repo (str): Name of the HF repo to upload to.
|
||||
hf_path (str): Path to the original Hugging Face model.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, ModelCard, logging
|
||||
|
||||
content = dedent(
|
||||
f"""
|
||||
---
|
||||
language: en
|
||||
license: other
|
||||
library: mlx
|
||||
tags:
|
||||
- mlx
|
||||
---
|
||||
|
||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
|
||||
converted to MLX format from
|
||||
[{hf_path}](https://huggingface.co/{hf_path}).
|
||||
|
||||
This model is intended to be used with the [EnCodec MLX
|
||||
example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
|
||||
"""
|
||||
)
|
||||
|
||||
card = ModelCard(content)
|
||||
card.save(os.path.join(path, "README.md"))
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=upload_repo, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=path,
|
||||
repo_id=upload_repo,
|
||||
repo_type="model",
|
||||
multi_commits=True,
|
||||
multi_commits_verbose=True,
|
||||
)
|
||||
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
|
||||
|
||||
|
||||
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
|
||||
if isinstance(save_path, str):
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||
mx.save_safetensors(
|
||||
str(save_path / "model.safetensors"), weights, metadata={"format": "mlx"}
|
||||
)
|
||||
|
||||
for weight_name in weights.keys():
|
||||
index_data["weight_map"][weight_name] = "model.safetensors"
|
||||
|
||||
index_data["weight_map"] = {
|
||||
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||
}
|
||||
|
||||
with open(save_path / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(index_data, f, indent=4)
|
||||
|
||||
|
||||
def save_config(
|
||||
config: dict,
|
||||
config_path: Union[str, Path],
|
||||
) -> None:
|
||||
"""Save the model configuration to the ``config_path``.
|
||||
|
||||
The final configuration will be sorted before saving for better readability.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
config_path (Union[str, Path]): Model configuration file path.
|
||||
"""
|
||||
# Clean unused keys
|
||||
config.pop("_name_or_path", None)
|
||||
|
||||
# sort the config for better readability
|
||||
config = dict(sorted(config.items()))
|
||||
|
||||
# write the updated config to the config_path (if provided)
|
||||
with open(config_path, "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
|
||||
def convert(
|
||||
upload: bool,
|
||||
model: str,
|
||||
dtype: str = None,
|
||||
):
|
||||
hf_repo = f"facebook/encodec_{model}"
|
||||
mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
|
||||
path = fetch_from_hub(hf_repo)
|
||||
save_path = Path("mlx_models")
|
||||
|
||||
weights = mx.load(str(Path(path) / "model.safetensors"))
|
||||
|
||||
with open(path / "config.json", "r") as fid:
|
||||
config = SimpleNamespace(**json.load(fid))
|
||||
|
||||
model = encodec.EncodecModel(config)
|
||||
|
||||
new_weights = {}
|
||||
for k, v in weights.items():
|
||||
basename, pname = k.rsplit(".", 1)
|
||||
if pname == "weight_v":
|
||||
g = weights[basename + ".weight_g"]
|
||||
v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
|
||||
k = basename + ".weight"
|
||||
elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
|
||||
continue
|
||||
elif "lstm" in basename:
|
||||
w_or_b, ih_or_hh, ln = pname.split("_")
|
||||
if w_or_b == "weight":
|
||||
new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
|
||||
elif w_or_b == "bias" and ih_or_hh == "ih":
|
||||
continue
|
||||
else:
|
||||
v = v + weights[k.replace("_hh_", "_ih_")]
|
||||
new_pname = "bias"
|
||||
k = basename + "." + ln[1:] + "." + new_pname
|
||||
if "conv.weight" in k:
|
||||
# Possibly a transposed conv which has a different order
|
||||
if "decoder" in k:
|
||||
ln = int(k.split(".")[2])
|
||||
if "conv" in model.decoder.layers[ln] and isinstance(
|
||||
model.decoder.layers[ln].conv, nn.ConvTranspose1d
|
||||
):
|
||||
v = mx.moveaxis(v, 0, 2)
|
||||
else:
|
||||
v = mx.moveaxis(v, 1, 2)
|
||||
else:
|
||||
v = mx.moveaxis(v, 1, 2)
|
||||
|
||||
new_weights[k] = v
|
||||
weights = new_weights
|
||||
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
if dtype is not None:
|
||||
t = getattr(mx, dtype)
|
||||
weights = {k: v.astype(t) for k, v in weights.items()}
|
||||
|
||||
if isinstance(save_path, str):
|
||||
save_path = Path(save_path)
|
||||
|
||||
save_weights(save_path, weights)
|
||||
|
||||
save_config(vars(config), config_path=save_path / "config.json")
|
||||
|
||||
if upload:
|
||||
upload_to_hub(save_path, mlx_repo, hf_repo)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="48khz",
|
||||
help="",
|
||||
choices=["24khz", "32khz", "48khz"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload",
|
||||
action="store_true",
|
||||
help="Upload the weights to Hugging Face.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
help="Data type to convert the model to.",
|
||||
default="float32",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(upload=args.upload, model=args.model, dtype=args.dtype)
|
||||
741
encodec/encodec.py
Normal file
741
encodec/encodec.py
Normal file
@@ -0,0 +1,741 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
_lstm_kernel = mx.fast.metal_kernel(
|
||||
name="lstm",
|
||||
input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
|
||||
output_names=["hidden_state", "cell_state"],
|
||||
header="""
|
||||
template <typename T>
|
||||
T sigmoid(T x) {
|
||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
""",
|
||||
source="""
|
||||
uint b = thread_position_in_grid.x;
|
||||
uint d = hidden_size * 4;
|
||||
|
||||
uint elem = b * d + thread_position_in_grid.y;
|
||||
uint index = elem;
|
||||
uint x_index = b * num_time_steps * d + time_step * d + index;
|
||||
|
||||
auto i = sigmoid(h_in[index] + x[x_index]);
|
||||
index += hidden_size;
|
||||
x_index += hidden_size;
|
||||
auto f = sigmoid(h_in[index] + x[x_index]);
|
||||
index += hidden_size;
|
||||
x_index += hidden_size;
|
||||
auto g = metal::precise::tanh(h_in[index] + x[x_index]);
|
||||
index += hidden_size;
|
||||
x_index += hidden_size;
|
||||
auto o = sigmoid(h_in[index] + x[x_index]);
|
||||
|
||||
cell_state[elem] = f * cell[elem] + i * g;
|
||||
hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def lstm_custom(x, h_in, cell, time_step):
|
||||
assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
|
||||
out_shape = cell.shape
|
||||
return _lstm_kernel(
|
||||
inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
|
||||
output_shapes=[out_shape, out_shape],
|
||||
output_dtypes=[h_in.dtype, h_in.dtype],
|
||||
grid=(x.shape[0], h_in.size // 4, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.Wx = mx.zeros((4 * hidden_size, input_size))
|
||||
self.Wh = mx.zeros((4 * hidden_size, hidden_size))
|
||||
self.bias = mx.zeros((4 * hidden_size,)) if bias else None
|
||||
|
||||
def __call__(self, x, hidden=None, cell=None):
|
||||
if self.bias is not None:
|
||||
x = mx.addmm(self.bias, x, self.Wx.T)
|
||||
else:
|
||||
x = x @ self.Wx.T
|
||||
|
||||
all_hidden = []
|
||||
|
||||
B = x.shape[0]
|
||||
cell = cell or mx.zeros((B, self.hidden_size), x.dtype)
|
||||
for t in range(x.shape[-2]):
|
||||
if hidden is None:
|
||||
hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
|
||||
else:
|
||||
hidden = hidden @ self.Wh.T
|
||||
hidden, cell = lstm_custom(x, hidden, cell, t)
|
||||
all_hidden.append(hidden)
|
||||
|
||||
return mx.stack(all_hidden, axis=-2)
|
||||
|
||||
|
||||
class EncodecConv1d(nn.Module):
|
||||
"""Conv1d with asymmetric or causal padding and normalization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.causal = config.use_causal_conv
|
||||
self.pad_mode = config.pad_mode
|
||||
self.norm_type = config.norm_type
|
||||
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels, out_channels, kernel_size, stride, dilation=dilation
|
||||
)
|
||||
if self.norm_type == "time_group_norm":
|
||||
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
|
||||
|
||||
self.stride = stride
|
||||
|
||||
# Effective kernel size with dilations.
|
||||
self.kernel_size = (kernel_size - 1) * dilation + 1
|
||||
|
||||
self.padding_total = kernel_size - stride
|
||||
|
||||
def _get_extra_padding_for_conv1d(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
length = hidden_states.shape[1]
|
||||
n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
|
||||
n_frames = int(math.ceil(n_frames)) - 1
|
||||
ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
|
||||
return ideal_length - length
|
||||
|
||||
def _pad1d(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
paddings: Tuple[int, int],
|
||||
mode: str = "zero",
|
||||
value: float = 0.0,
|
||||
):
|
||||
if mode != "reflect":
|
||||
return mx.pad(
|
||||
hidden_states, paddings, mode="constant", constant_values=value
|
||||
)
|
||||
|
||||
length = hidden_states.shape[1]
|
||||
prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
|
||||
suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
|
||||
return mx.concatenate([prefix, hidden_states, suffix], axis=1)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
|
||||
|
||||
if self.causal:
|
||||
# Left padding for causal
|
||||
hidden_states = self._pad1d(
|
||||
hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
|
||||
)
|
||||
else:
|
||||
# Asymmetric padding required for odd strides
|
||||
padding_right = self.padding_total // 2
|
||||
padding_left = self.padding_total - padding_right
|
||||
hidden_states = self._pad1d(
|
||||
hidden_states,
|
||||
(padding_left, padding_right + extra_padding),
|
||||
mode=self.pad_mode,
|
||||
)
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
if self.norm_type == "time_group_norm":
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncodecConvTranspose1d(nn.Module):
|
||||
"""ConvTranspose1d with asymmetric or causal padding and normalization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.causal = config.use_causal_conv
|
||||
self.trim_right_ratio = config.trim_right_ratio
|
||||
self.norm_type = config.norm_type
|
||||
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
||||
if config.norm_type == "time_group_norm":
|
||||
self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
|
||||
self.padding_total = kernel_size - stride
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
if self.norm_type == "time_group_norm":
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if self.causal:
|
||||
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
|
||||
else:
|
||||
padding_right = self.padding_total // 2
|
||||
|
||||
padding_left = self.padding_total - padding_right
|
||||
|
||||
end = hidden_states.shape[1] - padding_right
|
||||
hidden_states = hidden_states[:, padding_left:end, :]
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncodecLSTM(nn.Module):
|
||||
def __init__(self, config, dimension):
|
||||
super().__init__()
|
||||
self.lstm = [LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)]
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
h = hidden_states
|
||||
for lstm in self.lstm:
|
||||
h = lstm(h)
|
||||
return h + hidden_states
|
||||
|
||||
|
||||
class EncodecResnetBlock(nn.Module):
|
||||
"""
|
||||
Residual block from SEANet model as used by EnCodec.
|
||||
"""
|
||||
|
||||
def __init__(self, config, dim: int, dilations: List[int]):
|
||||
super().__init__()
|
||||
kernel_sizes = (config.residual_kernel_size, 1)
|
||||
if len(kernel_sizes) != len(dilations):
|
||||
raise ValueError("Number of kernel sizes should match number of dilations")
|
||||
|
||||
hidden = dim // config.compress
|
||||
block = []
|
||||
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
||||
in_chs = dim if i == 0 else hidden
|
||||
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
||||
block += [nn.ELU()]
|
||||
block += [
|
||||
EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
|
||||
]
|
||||
self.block = block
|
||||
|
||||
if getattr(config, "use_conv_shortcut", True):
|
||||
self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
|
||||
else:
|
||||
self.shortcut = nn.Identity()
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
residual = hidden_states
|
||||
for layer in self.block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return self.shortcut(residual) + hidden_states
|
||||
|
||||
|
||||
class EncodecEncoder(nn.Module):
|
||||
"""SEANet encoder as used by EnCodec."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
model = [
|
||||
EncodecConv1d(
|
||||
config, config.audio_channels, config.num_filters, config.kernel_size
|
||||
)
|
||||
]
|
||||
scaling = 1
|
||||
|
||||
for ratio in reversed(config.upsampling_ratios):
|
||||
current_scale = scaling * config.num_filters
|
||||
for j in range(config.num_residual_layers):
|
||||
model += [
|
||||
EncodecResnetBlock(
|
||||
config, current_scale, [config.dilation_growth_rate**j, 1]
|
||||
)
|
||||
]
|
||||
model += [nn.ELU()]
|
||||
model += [
|
||||
EncodecConv1d(
|
||||
config,
|
||||
current_scale,
|
||||
current_scale * 2,
|
||||
kernel_size=ratio * 2,
|
||||
stride=ratio,
|
||||
)
|
||||
]
|
||||
scaling *= 2
|
||||
|
||||
model += [EncodecLSTM(config, scaling * config.num_filters)]
|
||||
model += [nn.ELU()]
|
||||
model += [
|
||||
EncodecConv1d(
|
||||
config,
|
||||
scaling * config.num_filters,
|
||||
config.hidden_size,
|
||||
config.last_kernel_size,
|
||||
)
|
||||
]
|
||||
|
||||
self.layers = model
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncodecDecoder(nn.Module):
|
||||
"""SEANet decoder as used by EnCodec."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
scaling = int(2 ** len(config.upsampling_ratios))
|
||||
model = [
|
||||
EncodecConv1d(
|
||||
config,
|
||||
config.hidden_size,
|
||||
scaling * config.num_filters,
|
||||
config.kernel_size,
|
||||
)
|
||||
]
|
||||
|
||||
model += [EncodecLSTM(config, scaling * config.num_filters)]
|
||||
|
||||
for ratio in config.upsampling_ratios:
|
||||
current_scale = scaling * config.num_filters
|
||||
model += [nn.ELU()]
|
||||
model += [
|
||||
EncodecConvTranspose1d(
|
||||
config,
|
||||
current_scale,
|
||||
current_scale // 2,
|
||||
kernel_size=ratio * 2,
|
||||
stride=ratio,
|
||||
)
|
||||
]
|
||||
for j in range(config.num_residual_layers):
|
||||
model += [
|
||||
EncodecResnetBlock(
|
||||
config, current_scale // 2, (config.dilation_growth_rate**j, 1)
|
||||
)
|
||||
]
|
||||
scaling //= 2
|
||||
|
||||
model += [nn.ELU()]
|
||||
model += [
|
||||
EncodecConv1d(
|
||||
config,
|
||||
config.num_filters,
|
||||
config.audio_channels,
|
||||
config.last_kernel_size,
|
||||
)
|
||||
]
|
||||
self.layers = model
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncodecEuclideanCodebook(nn.Module):
|
||||
"""Codebook with Euclidean distance."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.embed = mx.zeros((config.codebook_size, config.codebook_dim))
|
||||
|
||||
def quantize(self, hidden_states):
|
||||
embed = self.embed.T
|
||||
scaled_states = hidden_states.square().sum(axis=1, keepdims=True)
|
||||
dist = -(
|
||||
scaled_states
|
||||
- 2 * hidden_states @ embed
|
||||
+ embed.square().sum(axis=0, keepdims=True)
|
||||
)
|
||||
embed_ind = dist.argmax(axis=-1)
|
||||
return embed_ind
|
||||
|
||||
def encode(self, hidden_states):
|
||||
shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape((-1, shape[-1]))
|
||||
embed_ind = self.quantize(hidden_states)
|
||||
embed_ind = embed_ind.reshape(*shape[:-1])
|
||||
return embed_ind
|
||||
|
||||
def decode(self, embed_ind):
|
||||
return self.embed[embed_ind]
|
||||
|
||||
|
||||
class EncodecVectorQuantization(nn.Module):
|
||||
"""
|
||||
Vector quantization implementation. Currently supports only euclidean distance.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.codebook = EncodecEuclideanCodebook(config)
|
||||
|
||||
def encode(self, hidden_states):
|
||||
return self.codebook.encode(hidden_states)
|
||||
|
||||
def decode(self, embed_ind):
|
||||
return self.codebook.decode(embed_ind)
|
||||
|
||||
|
||||
class EncodecResidualVectorQuantizer(nn.Module):
|
||||
"""Residual Vector Quantizer."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.codebook_size = config.codebook_size
|
||||
|
||||
hop_length = np.prod(config.upsampling_ratios)
|
||||
self.frame_rate = math.ceil(config.sampling_rate / hop_length)
|
||||
self.num_quantizers = int(
|
||||
1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
|
||||
)
|
||||
self.layers = [
|
||||
EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
|
||||
]
|
||||
|
||||
def get_num_quantizers_for_bandwidth(
|
||||
self, bandwidth: Optional[float] = None
|
||||
) -> int:
|
||||
"""Return num_quantizers based on specified target bandwidth."""
|
||||
bw_per_q = math.log2(self.codebook_size) * self.frame_rate
|
||||
num_quantizers = self.num_quantizers
|
||||
if bandwidth is not None and bandwidth > 0.0:
|
||||
num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
|
||||
return num_quantizers
|
||||
|
||||
def encode(
|
||||
self, embeddings: mx.array, bandwidth: Optional[float] = None
|
||||
) -> mx.array:
|
||||
"""
|
||||
Encode a given input array with the specified frame rate at the given
|
||||
bandwidth. The RVQ encode method sets the appropriate number of
|
||||
quantizers to use and returns indices for each quantizer.
|
||||
"""
|
||||
num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
|
||||
residual = embeddings
|
||||
all_indices = []
|
||||
for layer in self.layers[:num_quantizers]:
|
||||
indices = layer.encode(residual)
|
||||
quantized = layer.decode(indices)
|
||||
residual = residual - quantized
|
||||
all_indices.append(indices)
|
||||
out_indices = mx.stack(all_indices, axis=1)
|
||||
return out_indices
|
||||
|
||||
def decode(self, codes: mx.array) -> mx.array:
|
||||
"""Decode the given codes to the quantized representation."""
|
||||
quantized_out = None
|
||||
for i, indices in enumerate(codes.split(codes.shape[1], axis=1)):
|
||||
layer = self.layers[i]
|
||||
quantized = layer.decode(indices.squeeze(1))
|
||||
if quantized_out is None:
|
||||
quantized_out = quantized
|
||||
else:
|
||||
quantized_out = quantized + quantized_out
|
||||
return quantized_out
|
||||
|
||||
|
||||
class EncodecModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.encoder = EncodecEncoder(config)
|
||||
self.decoder = EncodecDecoder(config)
|
||||
self.quantizer = EncodecResidualVectorQuantizer(config)
|
||||
|
||||
def _encode_frame(
|
||||
self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
"""
|
||||
Encodes the given input using the underlying VQVAE.
|
||||
"""
|
||||
length = input_values.shape[1]
|
||||
duration = length / self.config.sampling_rate
|
||||
|
||||
if (
|
||||
self.config.chunk_length_s is not None
|
||||
and duration > 1e-5 + self.config.chunk_length_s
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
|
||||
)
|
||||
|
||||
scale = None
|
||||
if self.config.normalize:
|
||||
# if the padding is non zero
|
||||
input_values = input_values * padding_mask[..., None]
|
||||
mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
|
||||
scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
|
||||
input_values = input_values / scale
|
||||
|
||||
embeddings = self.encoder(input_values)
|
||||
codes = self.quantizer.encode(embeddings, bandwidth)
|
||||
return codes, scale
|
||||
|
||||
def encode(
|
||||
self,
|
||||
input_values: mx.array,
|
||||
padding_mask: mx.array = None,
|
||||
bandwidth: Optional[float] = None,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
"""
|
||||
Encodes the input audio waveform into discrete codes.
|
||||
|
||||
Args:
|
||||
input_values (mx.array): The input audio waveform with shape
|
||||
``(batch_size, channels, sequence_length)``.
|
||||
padding_mask (mx.array): Padding mask used to pad the ``input_values``.
|
||||
bandwidth (float, optional): The target bandwidth. Must be one of
|
||||
``config.target_bandwidths``. If ``None``, uses the smallest
|
||||
possible bandwidth. bandwidth is represented as a thousandth of
|
||||
what it is, e.g. 6kbps bandwidth is represented as bandwidth == 6.0
|
||||
|
||||
Returns:
|
||||
A list of frames containing the discrete encoded codes for the
|
||||
input audio waveform, along with rescaling factors for each chunk
|
||||
when ``config.normalize==True``. Each frame is a tuple ``(codebook,
|
||||
scale)``, with ``codebook`` of shape ``(batch_size, num_codebooks,
|
||||
frames)``.
|
||||
"""
|
||||
|
||||
if bandwidth is None:
|
||||
bandwidth = self.config.target_bandwidths[0]
|
||||
if bandwidth not in self.config.target_bandwidths:
|
||||
raise ValueError(
|
||||
f"This model doesn't support the bandwidth {bandwidth}. "
|
||||
f"Select one of {self.config.target_bandwidths}."
|
||||
)
|
||||
|
||||
_, input_length, channels = input_values.shape
|
||||
|
||||
if channels < 1 or channels > 2:
|
||||
raise ValueError(
|
||||
f"Number of audio channels must be 1 or 2, but got {channels}"
|
||||
)
|
||||
|
||||
chunk_length = self.chunk_length
|
||||
if chunk_length is None:
|
||||
chunk_length = input_length
|
||||
stride = input_length
|
||||
else:
|
||||
stride = self.chunk_stride
|
||||
|
||||
if padding_mask is None:
|
||||
padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
|
||||
encoded_frames = []
|
||||
scales = []
|
||||
|
||||
step = chunk_length - stride
|
||||
if (input_length % stride) != step:
|
||||
raise ValueError(
|
||||
"The input length is not properly padded for batched chunked "
|
||||
"encoding. Make sure to pad the input correctly."
|
||||
)
|
||||
|
||||
for offset in range(0, input_length - step, stride):
|
||||
mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
|
||||
frame = input_values[:, offset : offset + chunk_length]
|
||||
encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
|
||||
encoded_frames.append(encoded_frame)
|
||||
scales.append(scale)
|
||||
|
||||
encoded_frames = mx.stack(encoded_frames)
|
||||
|
||||
return (encoded_frames, scales)
|
||||
|
||||
@staticmethod
|
||||
def _linear_overlap_add(frames: List[mx.array], stride: int):
|
||||
if len(frames) == 0:
|
||||
raise ValueError("`frames` cannot be an empty list.")
|
||||
|
||||
dtype = frames[0].dtype
|
||||
N, frame_length, C = frames[0].shape
|
||||
total_size = stride * (len(frames) - 1) + frames[-1].shape[1]
|
||||
|
||||
time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
|
||||
weight = 0.5 - (time_vec - 0.5).abs()
|
||||
|
||||
weight = weight[:, None]
|
||||
sum_weight = mx.zeros((total_size, 1), dtype=dtype)
|
||||
out = mx.zeros((N, total_size, C), dtype=dtype)
|
||||
offset = 0
|
||||
|
||||
for frame in frames:
|
||||
frame_length = frame.shape[1]
|
||||
out[:, offset : offset + frame_length] += weight[:frame_length] * frame
|
||||
sum_weight[offset : offset + frame_length] += weight[:frame_length]
|
||||
offset += stride
|
||||
|
||||
return out / sum_weight
|
||||
|
||||
def _decode_frame(
|
||||
self, codes: mx.array, scale: Optional[mx.array] = None
|
||||
) -> mx.array:
|
||||
embeddings = self.quantizer.decode(codes)
|
||||
outputs = self.decoder(embeddings)
|
||||
if scale is not None:
|
||||
outputs = outputs * scale
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def channels(self):
|
||||
return self.config.audio_channels
|
||||
|
||||
@property
|
||||
def sampling_rate(self):
|
||||
return self.config.sampling_rate
|
||||
|
||||
@property
|
||||
def chunk_length(self):
|
||||
if self.config.chunk_length_s is None:
|
||||
return None
|
||||
else:
|
||||
return int(self.config.chunk_length_s * self.config.sampling_rate)
|
||||
|
||||
@property
|
||||
def chunk_stride(self):
|
||||
if self.config.chunk_length_s is None or self.config.overlap is None:
|
||||
return None
|
||||
else:
|
||||
return max(1, int((1.0 - self.config.overlap) * self.chunk_length))
|
||||
|
||||
def decode(
|
||||
self,
|
||||
audio_codes: mx.array,
|
||||
audio_scales: Union[mx.array, List[mx.array]],
|
||||
padding_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Decodes the given frames into an output audio waveform.
|
||||
|
||||
Note that the output might be a bit bigger than the input. In that
|
||||
case, any extra steps at the end should be trimmed.
|
||||
|
||||
Args:
|
||||
audio_codes (mx.array): Discret code embeddings of shape
|
||||
``(batch_size, nb_chunks, chunk_length)``.
|
||||
audio_scales (mx.array): Scaling factor for each input.
|
||||
padding_mask (mx.array): Padding mask.
|
||||
"""
|
||||
chunk_length = self.chunk_length
|
||||
if chunk_length is None:
|
||||
if audio_codes.shape[1] != 1:
|
||||
raise ValueError(f"Expected one frame, got {len(audio_codes)}")
|
||||
audio_values = self._decode_frame(audio_codes[:, 0], audio_scales[0])
|
||||
else:
|
||||
decoded_frames = []
|
||||
|
||||
for frame, scale in zip(audio_codes, audio_scales):
|
||||
frames = self._decode_frame(frame, scale)
|
||||
decoded_frames.append(frames)
|
||||
|
||||
audio_values = self._linear_overlap_add(
|
||||
decoded_frames, self.chunk_stride or 1
|
||||
)
|
||||
|
||||
# truncate based on padding mask
|
||||
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
|
||||
audio_values = audio_values[:, : padding_mask.shape[1]]
|
||||
return audio_values
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_or_repo: str):
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
||||
|
||||
def preprocess_audio(
|
||||
raw_audio: Union[mx.array, List[mx.array]],
|
||||
sampling_rate: int = 24000,
|
||||
chunk_length: Optional[int] = None,
|
||||
chunk_stride: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Prepare inputs for the EnCodec model.
|
||||
|
||||
Args:
|
||||
raw_audio (mx.array or List[mx.array]): The sequence or batch of
|
||||
sequences to be processed.
|
||||
sampling_rate (int): The sampling rate at which the audio waveform
|
||||
should be digitalized.
|
||||
chunk_length (int, optional): The model's chunk length.
|
||||
chunk_stride (int, optional): The model's chunk stride.
|
||||
"""
|
||||
if not isinstance(raw_audio, list):
|
||||
raw_audio = [raw_audio]
|
||||
|
||||
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
|
||||
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
if chunk_length is not None:
|
||||
max_length += chunk_length - (max_length % chunk_stride)
|
||||
|
||||
inputs = []
|
||||
masks = []
|
||||
for x in raw_audio:
|
||||
length = x.shape[0]
|
||||
mask = mx.ones((length,), dtype=mx.bool_)
|
||||
difference = max_length - length
|
||||
if difference > 0:
|
||||
mask = mx.pad(mask, (0, difference))
|
||||
x = mx.pad(x, ((0, difference), (0, 0)))
|
||||
inputs.append(x)
|
||||
masks.append(mask)
|
||||
return mx.stack(inputs), mx.stack(masks)
|
||||
39
encodec/example.py
Normal file
39
encodec/example.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
from encodec import EncodecModel
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
|
||||
|
||||
# Preprocess the audio (this can also be a list of arrays for batched
|
||||
# processing).
|
||||
feats, mask = processor(audio)
|
||||
|
||||
|
||||
# Encode at the given bandwidth. A lower bandwidth results in more
|
||||
# compression but lower reconstruction quality.
|
||||
@mx.compile
|
||||
def encode(feats, mask):
|
||||
return model.encode(feats, mask, bandwidth=3)
|
||||
|
||||
|
||||
# Decode to reconstruct the audio
|
||||
@mx.compile
|
||||
def decode(codes, scales, mask):
|
||||
return model.decode(codes, scales, mask)
|
||||
|
||||
|
||||
codes, scales = encode(feats, mask)
|
||||
reconstructed = decode(codes, scales, mask)
|
||||
|
||||
# Trim any padding:
|
||||
reconstructed = reconstructed[0, : len(audio)]
|
||||
|
||||
# Save the audio as a wave file
|
||||
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
|
||||
3
encodec/requirements.txt
Normal file
3
encodec/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
mlx>=0.18
|
||||
numpy
|
||||
huggingface_hub
|
||||
67
encodec/test.py
Normal file
67
encodec/test.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoProcessor
|
||||
from transformers import EncodecModel as PTEncodecModel
|
||||
|
||||
from encodec import EncodecModel, preprocess_audio
|
||||
|
||||
|
||||
def compare_processors():
|
||||
np.random.seed(0)
|
||||
audio_length = 95500
|
||||
audio = np.random.uniform(size=(2, audio_length)).astype(np.float32)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
|
||||
|
||||
pt_inputs = processor(
|
||||
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
mx_inputs = preprocess_audio(
|
||||
mx.array(audio).T,
|
||||
processor.sampling_rate,
|
||||
processor.chunk_length,
|
||||
processor.chunk_stride,
|
||||
)
|
||||
|
||||
assert np.array_equal(pt_inputs["input_values"], mx_inputs[0].moveaxis(2, 1))
|
||||
assert np.array_equal(pt_inputs["padding_mask"], mx_inputs[1])
|
||||
|
||||
|
||||
def compare_models():
|
||||
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
np.random.seed(0)
|
||||
audio_length = 190560
|
||||
audio = np.random.uniform(size=(1, audio_length, 2)).astype(np.float32)
|
||||
mask = np.ones((1, audio_length), dtype=np.int32)
|
||||
pt_encoded = pt_model.encode(
|
||||
torch.tensor(audio).moveaxis(2, 1), torch.tensor(mask)[None]
|
||||
)
|
||||
mx_encoded = mx_model.encode(mx.array(audio), mx.array(mask))
|
||||
pt_codes = pt_encoded.audio_codes.numpy()
|
||||
mx_codes = mx_encoded[0]
|
||||
assert np.array_equal(pt_codes, mx_codes), "Encoding codes mismatch"
|
||||
|
||||
for mx_scale, pt_scale in zip(mx_encoded[1], pt_encoded.audio_scales):
|
||||
if mx_scale is not None:
|
||||
pt_scale = pt_scale.numpy()
|
||||
assert np.allclose(pt_scale, mx_scale, atol=1e-3, rtol=1e-4)
|
||||
|
||||
pt_audio = pt_model.decode(
|
||||
pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
|
||||
)
|
||||
pt_audio = pt_audio[0].squeeze().T.detach().numpy()
|
||||
mx_audio = mx_model.decode(*mx_encoded, mx.array(mask))
|
||||
mx_audio = mx_audio.squeeze()
|
||||
assert np.allclose(
|
||||
pt_audio, mx_audio, atol=1e-4, rtol=1e-4
|
||||
), "Decoding audio mismatch"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
compare_processors()
|
||||
compare_models()
|
||||
52
encodec/utils.py
Normal file
52
encodec/utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
"""
|
||||
Save audio to a wave (.wav) file.
|
||||
"""
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
audio = (audio * 32767).astype(mx.int16)
|
||||
write(file, sampling_rate, np.array(audio))
|
||||
|
||||
|
||||
def load_audio(file: str, sampling_rate: int, channels: int):
|
||||
"""
|
||||
Read audio into an mx.array, resampling if necessary.
|
||||
|
||||
Args:
|
||||
file (str): The audio file to open.
|
||||
sampling_rate (int): The sample rate to resample the audio at if needed.
|
||||
channels (int): The number of audio channels.
|
||||
|
||||
Returns:
|
||||
An mx.array containing the audio waveform in float32.
|
||||
"""
|
||||
from subprocess import CalledProcessError, run
|
||||
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# fmt: off
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", str(channels),
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sampling_rate),
|
||||
"-"
|
||||
]
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
except CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
out = mx.array(np.frombuffer(out, np.int16))
|
||||
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
|
||||
281
flux/README.md
Normal file
281
flux/README.md
Normal file
@@ -0,0 +1,281 @@
|
||||
FLUX
|
||||
====
|
||||
|
||||
FLUX implementation in MLX. The implementation is ported directly from
|
||||
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
|
||||
and the model weights are downloaded directly from the Hugging Face Hub.
|
||||
|
||||
The goal of this example is to be clean, educational and to allow for
|
||||
experimentation with finetuning FLUX models as well as adding extra
|
||||
functionality such as in-/outpainting, guidance with custom losses etc.
|
||||
|
||||

|
||||
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
|
||||
tron emanating futuristic technology with the word "MLX" in the center with
|
||||
capital red letters.'*
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
The dependencies are minimal, namely:
|
||||
|
||||
- `huggingface-hub` to download the checkpoints.
|
||||
- `regex` for the tokenization
|
||||
- `tqdm`, `PIL`, and `numpy` for the scripts
|
||||
- `sentencepiece` for the T5 tokenizer
|
||||
- `datasets` for using an HF dataset directly
|
||||
|
||||
You can install all of the above with the `requirements.txt` as follows:
|
||||
|
||||
pip install -r requirements.txt
|
||||
|
||||
|
||||
Usage
|
||||
---------
|
||||
|
||||
You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`.
|
||||
|
||||
```shell
|
||||
python txt2image.py --model schnell \
|
||||
--n-images 1 \
|
||||
--image-size 256x512 \
|
||||
--verbose \
|
||||
'A photo of an astronaut riding a horse on Mars.'
|
||||
```
|
||||
|
||||
For more parameters, please use the `--help` command to view.
|
||||
|
||||
```shell
|
||||
python txt2image.py --help
|
||||
```
|
||||
|
||||
Inference
|
||||
---------
|
||||
|
||||
Inference in this example is similar to the stable diffusion example. The
|
||||
classes to get you started are `FluxPipeline` from the `flux` module.
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from flux import FluxPipeline
|
||||
|
||||
# This will download all the weights from HF hub
|
||||
flux = FluxPipeline("flux-schnell")
|
||||
|
||||
# Make a generator that returns the latent variables from the reverse diffusion
|
||||
# process
|
||||
latent_generator = flux.generate_latents(
|
||||
"A photo of an astronaut riding a horse on Mars",
|
||||
num_steps=4,
|
||||
latent_size=(32, 64), # 256x512 image
|
||||
)
|
||||
|
||||
# The first return value of the generator contains the conditioning and the
|
||||
# random noise at the beginning of the diffusion process.
|
||||
conditioning = next(latent_generator)
|
||||
(
|
||||
x_T, # The initial noise
|
||||
x_positions, # The integer positions used for image positional encoding
|
||||
t5_conditioning, # The T5 features from the text prompt
|
||||
t5_positions, # Integer positions for text (normally all 0s)
|
||||
clip_conditioning, # The clip text features from the text prompt
|
||||
) = conditioning
|
||||
|
||||
# Returning the conditioning as the first output from the generator allows us
|
||||
# to unload T5 and clip before running the diffusion transformer.
|
||||
mx.eval(conditioning)
|
||||
|
||||
# Evaluate each diffusion step
|
||||
for x_t in latent_generator:
|
||||
mx.eval(x_t)
|
||||
|
||||
# Note that we need to pass the latent size because it is collapsed and
|
||||
# patchified in x_t and we need to unwrap it.
|
||||
img = flux.decode(x_t, latent_size=(32, 64))
|
||||
```
|
||||
|
||||
The above are essentially the implementation of the `txt2image.py` script
|
||||
except for some additional logic to quantize and/or load trained adapters. One
|
||||
can use the script as follows:
|
||||
|
||||
```shell
|
||||
python txt2image.py \
|
||||
--n-images 4 \
|
||||
--n-rows 2 \
|
||||
--image-size 256x512 \
|
||||
'A photo of an astronaut riding a horse on Mars.'
|
||||
```
|
||||
|
||||
### Experimental Options
|
||||
|
||||
FLUX pads the prompt to a specific size of 512 tokens for the dev model and
|
||||
256 for the schnell model. Not applying padding results in faster generation
|
||||
but it is not clear how it may affect the generated images. To enable that
|
||||
option in this example pass `--no-t5-padding` to the `txt2image.py` script or
|
||||
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
|
||||
|
||||
Finetuning
|
||||
----------
|
||||
|
||||
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
|
||||
but ymmv) on a provided image dataset. The dataset folder must have an
|
||||
`train.jsonl` file with the following format:
|
||||
|
||||
```jsonl
|
||||
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
|
||||
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
|
||||
...
|
||||
```
|
||||
|
||||
The training script by default trains for 600 iterations with a batch size of
|
||||
1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py
|
||||
--help` for the list of hyperparameters you can tune.
|
||||
|
||||
> [!Note]
|
||||
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
|
||||
> should reduce this number significantly.
|
||||
|
||||
### Training Example
|
||||
|
||||
This is a step-by-step finetuning example. We will be using the data from
|
||||
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
|
||||
In particular, we will use `dog6` which is a popular example for showcasing
|
||||
dreambooth [^1].
|
||||
|
||||
The training images are the following 5 images [^2]:
|
||||
|
||||

|
||||
|
||||
We start by making the following `train.jsonl` file and placing it in the same
|
||||
folder as the images.
|
||||
|
||||
```jsonl
|
||||
{"image": "00.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "01.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "02.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "03.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "04.jpg", "prompt": "A photo of sks dog"}
|
||||
```
|
||||
|
||||
Subsequently we finetune FLUX using the following command:
|
||||
|
||||
```shell
|
||||
python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 8 \
|
||||
path/to/dreambooth/dataset/dog6
|
||||
```
|
||||
|
||||
Or you can directly use the pre-processed Hugging Face dataset
|
||||
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
|
||||
for fine-tuning.
|
||||
|
||||
```shell
|
||||
python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 8 \
|
||||
mlx-community/dreambooth-dog6
|
||||
```
|
||||
|
||||
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
|
||||
bit more than 1 hour.
|
||||
|
||||
### Using the Adapter
|
||||
|
||||
The adapters are saved in `mlx_output` and can be used directly by the
|
||||
`txt2image.py` script. For instance,
|
||||
|
||||
```shell
|
||||
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
|
||||
--adapter mlx_output/final_adapters.safetensors \
|
||||
--fuse-adapter \
|
||||
--no-t5-padding \
|
||||
'A photo of an sks dog lying on the sand at a beach in Greece'
|
||||
```
|
||||
|
||||
generates an image that looks like the following,
|
||||
|
||||

|
||||
|
||||
and of course we can pass `--image-size 512x1024` to get larger images with
|
||||
different aspect ratios,
|
||||
|
||||

|
||||
|
||||
The arguments that are relevant to the adapters are of course `--adapter` and
|
||||
`--fuse-adapter`. The first defines the path to an adapter to apply to the
|
||||
model and the second fuses the adapter back into the model to get a bit more
|
||||
speed during generation.
|
||||
|
||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
|
||||
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .
|
||||
|
||||
|
||||
Distributed Computation
|
||||
------------------------
|
||||
|
||||
The FLUX example supports distributed computation during both generation and
|
||||
training. See the [distributed communication
|
||||
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
|
||||
for information on how to set-up MLX for distributed communication. The rest of
|
||||
this section assumes you can launch distributed MLX programs using `mlx.launch
|
||||
--hostfile hostfile.json`.
|
||||
|
||||
### Distributed Finetuning
|
||||
|
||||
Distributed finetuning scales very well with FLUX and all one has to do is
|
||||
adjust the gradient accumulation and training iterations so that the batch
|
||||
size remains the same. For instance, to replicate the following training
|
||||
|
||||
```shell
|
||||
python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 8 \
|
||||
mlx-community/dreambooth-dog6
|
||||
```
|
||||
|
||||
On 4 machines we simply run
|
||||
|
||||
```shell
|
||||
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 150 --iterations 300 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 2 \
|
||||
mlx-community/dreambooth-dog6
|
||||
```
|
||||
|
||||
Note the iterations that changed to 300 from 1200 and the gradient accumulations to 2 from 8.
|
||||
|
||||
### Distributed Inference
|
||||
|
||||
Distributed inference can be divided in two different approaches. The first
|
||||
approach is the data-parallel approach, where each node generates its own
|
||||
images and shares them at the end. The second approach is the model-parallel
|
||||
approach where the model is shared across the nodes and they collaboratively
|
||||
generate the images.
|
||||
|
||||
The `txt2image.py` script will attempt to choose the best approach depending on
|
||||
how many images are being generated across the nodes. The model-parallel
|
||||
approach can be forced by passing the argument `--force-shard`.
|
||||
|
||||
For better performance in the model-parallel approach we suggest that you use a
|
||||
[thunderbolt
|
||||
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
|
||||
|
||||
All you have to do once again is use `mlx.launch` as follows
|
||||
|
||||
```shell
|
||||
mlx.launch --verbose --hostfile hostfile.json -- \
|
||||
python txt2image.py --model schnell \
|
||||
--n-images 8 \
|
||||
--image-size 512x512 \
|
||||
--verbose \
|
||||
'A photo of an astronaut riding a horse on Mars'
|
||||
```
|
||||
|
||||
for model-parallel generation you may want to also pass `--env
|
||||
MLX_METAL_FAST_SYNCH=1` to `mlx.launch` which is an experimental setting that
|
||||
reduces the CPU/GPU synchronization overhead.
|
||||
292
flux/dreambooth.py
Normal file
292
flux/dreambooth.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
|
||||
from flux import FluxPipeline, Trainer, load_dataset, save_config
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
"""Generate images to monitor the progress of the finetuning."""
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"{iteration:07d}_progress.png"
|
||||
print(f"Generating {str(out_file)}", flush=True)
|
||||
|
||||
# Generate some images and arrange them in a grid
|
||||
n_rows = 2
|
||||
n_images = 4
|
||||
x = flux.generate_images(
|
||||
args.progress_prompt,
|
||||
n_images,
|
||||
args.progress_steps,
|
||||
)
|
||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(out_file)
|
||||
|
||||
|
||||
def save_adapters(adapter_name, flux, args):
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / adapter_name
|
||||
print(f"Saving {str(out_file)}")
|
||||
|
||||
mx.save_safetensors(
|
||||
str(out_file),
|
||||
dict(tree_flatten(flux.flow.trainable_parameters())),
|
||||
metadata={
|
||||
"lora_rank": str(args.lora_rank),
|
||||
"lora_blocks": str(args.lora_blocks),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Finetune Flux to generate images with a specific subject"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="dev",
|
||||
choices=[
|
||||
"dev",
|
||||
"schnell",
|
||||
],
|
||||
help="Which flux model to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance", type=float, default=4.0, help="The guidance factor to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=600,
|
||||
help="How many iterations to train for",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The batch size to use when training the stable diffusion model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=lambda x: tuple(map(int, x.split("x"))),
|
||||
default=(512, 512),
|
||||
help="The resolution of the training images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-augmentations",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Augment the images by random cropping and panning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-prompt",
|
||||
required=True,
|
||||
help="Use this prompt when generating images for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Use this many steps when generating images for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-every",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Generate images every PROGRESS_EVERY steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-every",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Save the model every CHECKPOINT_EVERY steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-blocks",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Train the last LORA_BLOCKS transformer blocks",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-steps", type=int, default=100, help="Learning rate warmup"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate", type=float, default="1e-4", help="Learning rate for training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad-accumulate",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Accumulate gradients for that many iterations before applying them",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
||||
)
|
||||
|
||||
parser.add_argument("dataset")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
output_path = Path(args.output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
save_config(vars(args), output_path / "adapter_config.json")
|
||||
|
||||
# Load the model and set it up for LoRA training. We use the same random
|
||||
# state when creating the LoRA layers so all workers will have the same
|
||||
# initial weights.
|
||||
mx.random.seed(0x0F0F0F0F)
|
||||
flux = FluxPipeline("flux-" + args.model)
|
||||
flux.flow.freeze()
|
||||
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
|
||||
|
||||
# Reset the seed to a different seed per worker if we are in distributed
|
||||
# mode so that each worker is working on different data, diffusion step and
|
||||
# random noise.
|
||||
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
|
||||
|
||||
# Report how many parameters we are training
|
||||
trainable_params = tree_reduce(
|
||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
||||
)
|
||||
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
|
||||
|
||||
# Set up the optimizer and training steps. The steps are a bit verbose to
|
||||
# support gradient accumulation together with compilation.
|
||||
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
|
||||
cosine = optim.cosine_decay(
|
||||
args.learning_rate, args.iterations // args.grad_accumulate
|
||||
)
|
||||
lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
|
||||
optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def single_step(x, t5_feat, clip_feat, guidance):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
grads = average_gradients(grads)
|
||||
optimizer.update(flux.flow, grads)
|
||||
|
||||
return loss
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
|
||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
return loss, grads
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
grads = tree_map(
|
||||
lambda a, b: (a + b) / args.grad_accumulate,
|
||||
prev_grads,
|
||||
grads,
|
||||
)
|
||||
grads = average_gradients(grads)
|
||||
optimizer.update(flux.flow, grads)
|
||||
|
||||
return loss
|
||||
|
||||
# We simply route to the appropriate step based on whether we have
|
||||
# gradients from a previous step and whether we should be performing an
|
||||
# update or simply computing and accumulating gradients in this step.
|
||||
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
|
||||
if prev_grads is None:
|
||||
if perform_step:
|
||||
return single_step(x, t5_feat, clip_feat, guidance), None
|
||||
else:
|
||||
return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
|
||||
else:
|
||||
if perform_step:
|
||||
return (
|
||||
grad_accumulate_and_step(
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
return compute_loss_and_accumulate_grads(
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
dataset = load_dataset(args.dataset)
|
||||
trainer = Trainer(flux, dataset, args)
|
||||
trainer.encode_dataset()
|
||||
|
||||
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
||||
|
||||
# An initial generation to compare
|
||||
generate_progress_images(0, flux, args)
|
||||
|
||||
grads = None
|
||||
losses = []
|
||||
tic = time.time()
|
||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||
mx.eval(loss, grads, state)
|
||||
losses.append(loss.item())
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
toc = time.time()
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||
print(
|
||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"It/s: {10 / (toc - tic):.3f} "
|
||||
f"Peak mem: {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if (i + 1) % args.progress_every == 0:
|
||||
generate_progress_images(i + 1, flux, args)
|
||||
|
||||
if (i + 1) % args.checkpoint_every == 0:
|
||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
losses = []
|
||||
tic = time.time()
|
||||
|
||||
save_adapters("final_adapters.safetensors", flux, args)
|
||||
print("Training successful.")
|
||||
16
flux/flux/__init__.py
Normal file
16
flux/flux/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from .datasets import Dataset, load_dataset
|
||||
from .flux import FluxPipeline
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .trainer import Trainer
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_clip_tokenizer,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
save_config,
|
||||
)
|
||||
357
flux/flux/autoencoder.py
Normal file
357
flux/flux/autoencoder.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.nn.layers.upsample import upsample_nearest
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: List[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(
|
||||
num_groups=32,
|
||||
dims=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.q = nn.Linear(in_channels, in_channels)
|
||||
self.k = nn.Linear(in_channels, in_channels)
|
||||
self.v = nn.Linear(in_channels, in_channels)
|
||||
self.proj_out = nn.Linear(in_channels, in_channels)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
B, H, W, C = x.shape
|
||||
|
||||
y = x.reshape(B, 1, -1, C)
|
||||
y = self.norm(y)
|
||||
q = self.q(y)
|
||||
k = self.k(y)
|
||||
v = self.v(y)
|
||||
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
|
||||
y = self.proj_out(y)
|
||||
|
||||
return x + y.reshape(B, H, W, C)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(
|
||||
num_groups=32,
|
||||
dims=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=32,
|
||||
dims=out_channels,
|
||||
eps=1e-6,
|
||||
affine=True,
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def __call__(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
x = upsample_nearest(x, (2, 2))
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = []
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = []
|
||||
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = {}
|
||||
down["block"] = block
|
||||
down["attn"] = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down["downsample"] = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = {}
|
||||
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid["attn_1"] = AttnBlock(block_in)
|
||||
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(
|
||||
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level]["block"][i_block](hs[-1])
|
||||
|
||||
# TODO: Remove the attn
|
||||
if len(self.down[i_level]["attn"]) > 0:
|
||||
h = self.down[i_level]["attn"][i_block](h)
|
||||
|
||||
hs.append(h)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level]["downsample"](hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid["block_1"](h)
|
||||
h = self.mid["attn_1"](h)
|
||||
h = self.mid["block_2"](h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = {}
|
||||
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid["attn_1"] = AttnBlock(block_in)
|
||||
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# upsampling
|
||||
self.up = []
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = []
|
||||
attn = [] # TODO: Remove the attn, nobody appends anything to it
|
||||
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
up = {}
|
||||
up["block"] = block
|
||||
up["attn"] = attn
|
||||
if i_level != 0:
|
||||
up["upsample"] = Upsample(block_in)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(
|
||||
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def __call__(self, z: mx.array):
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid["block_1"](h)
|
||||
h = self.mid["attn_1"](h)
|
||||
h = self.mid["block_2"](h)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level]["block"][i_block](h)
|
||||
|
||||
# TODO: Remove the attn
|
||||
if len(self.up[i_level]["attn"]) > 0:
|
||||
h = self.up[i_level]["attn"][i_block](h)
|
||||
|
||||
if i_level != 0:
|
||||
h = self.up[i_level]["upsample"](h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __call__(self, z: mx.array):
|
||||
mean, logvar = mx.split(z, 2, axis=-1)
|
||||
if self.training:
|
||||
std = mx.exp(0.5 * logvar)
|
||||
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
|
||||
return mean + std * eps
|
||||
else:
|
||||
return mean
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, params: AutoEncoderParams):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
out_ch=params.out_ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.reg = DiagonalGaussian()
|
||||
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for k, w in weights.items():
|
||||
if w.ndim == 4:
|
||||
w = w.transpose(0, 2, 3, 1)
|
||||
w = w.reshape(-1).reshape(w.shape)
|
||||
if w.shape[1:3] == (1, 1):
|
||||
w = w.squeeze((1, 2))
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def encode(self, x: mx.array):
|
||||
z = self.reg(self.encoder(x))
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: mx.array):
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
return self.decoder(z)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return self.decode(self.encode(x))
|
||||
154
flux/flux/clip.py
Normal file
154
flux/flux/clip.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextModelConfig:
|
||||
num_layers: int = 23
|
||||
model_dims: int = 1024
|
||||
num_heads: int = 16
|
||||
max_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
hidden_act: str = "quick_gelu"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config):
|
||||
return cls(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
model_dims=config["hidden_size"],
|
||||
num_heads=config["num_attention_heads"],
|
||||
max_length=config["max_position_embeddings"],
|
||||
vocab_size=config["vocab_size"],
|
||||
hidden_act=config["hidden_act"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPOutput:
|
||||
# The last_hidden_state indexed at the EOS token and possibly projected if
|
||||
# the model has a projection layer
|
||||
pooled_output: Optional[mx.array] = None
|
||||
|
||||
# The full sequence output of the transformer after the final layernorm
|
||||
last_hidden_state: Optional[mx.array] = None
|
||||
|
||||
# A list of hidden states corresponding to the outputs of the transformer layers
|
||||
hidden_states: Optional[List[mx.array]] = None
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, model_dims: int, num_heads: int, activation: str):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||
|
||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
|
||||
|
||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||
|
||||
self.act = _ACTIVATIONS[activation]
|
||||
|
||||
def __call__(self, x, attn_mask=None):
|
||||
y = self.layer_norm1(x)
|
||||
y = self.attention(y, y, y, attn_mask)
|
||||
x = y + x
|
||||
|
||||
y = self.layer_norm2(x)
|
||||
y = self.linear1(y)
|
||||
y = self.act(y)
|
||||
y = self.linear2(y)
|
||||
x = y + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextModelConfig):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||
self.layers = [
|
||||
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||
|
||||
def _get_mask(self, N, dtype):
|
||||
indices = mx.arange(N)
|
||||
mask = indices[:, None] < indices[None]
|
||||
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
|
||||
return mask
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for key, w in weights.items():
|
||||
# Remove prefixes
|
||||
if key.startswith("text_model."):
|
||||
key = key[11:]
|
||||
if key.startswith("embeddings."):
|
||||
key = key[11:]
|
||||
if key.startswith("encoder."):
|
||||
key = key[8:]
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
|
||||
new_weights[key] = w
|
||||
|
||||
return new_weights
|
||||
|
||||
def __call__(self, x):
|
||||
# Extract some shapes
|
||||
B, N = x.shape
|
||||
eos_tokens = x.argmax(-1)
|
||||
|
||||
# Compute the embeddings
|
||||
x = self.token_embedding(x)
|
||||
x = x + self.position_embedding.weight[:N]
|
||||
|
||||
# Compute the features from the transformer
|
||||
mask = self._get_mask(N, x.dtype)
|
||||
hidden_states = []
|
||||
for l in self.layers:
|
||||
x = l(x, mask)
|
||||
hidden_states.append(x)
|
||||
|
||||
# Apply the final layernorm and return
|
||||
x = self.final_layer_norm(x)
|
||||
last_hidden_state = x
|
||||
|
||||
# Select the EOS token
|
||||
pooled_output = x[mx.arange(len(x)), eos_tokens]
|
||||
|
||||
return CLIPOutput(
|
||||
pooled_output=pooled_output,
|
||||
last_hidden_state=last_hidden_state,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
75
flux/flux/datasets.py
Normal file
75
flux/flux/datasets.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __getitem__(self, index: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LocalDataset(Dataset):
|
||||
prompt_key = "prompt"
|
||||
|
||||
def __init__(self, dataset: str, data_file):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(data_file, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
image = Image.open(self.dataset_base / item["image"])
|
||||
return image, item[self.prompt_key]
|
||||
|
||||
|
||||
class LegacyDataset(LocalDataset):
|
||||
prompt_key = "text"
|
||||
|
||||
def __init__(self, dataset: str):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(self.dataset_base / "index.json") as f:
|
||||
self._data = json.load(f)["data"]
|
||||
|
||||
|
||||
class HuggingFaceDataset(Dataset):
|
||||
|
||||
def __init__(self, dataset: str):
|
||||
from datasets import load_dataset as hf_load_dataset
|
||||
|
||||
self._df = hf_load_dataset(dataset)["train"]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._df)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._df[index]
|
||||
return item["image"], item["prompt"]
|
||||
|
||||
|
||||
def load_dataset(dataset: str):
|
||||
dataset_base = Path(dataset)
|
||||
data_file = dataset_base / "train.jsonl"
|
||||
legacy_file = dataset_base / "index.json"
|
||||
|
||||
if data_file.exists():
|
||||
print(f"Load the local dataset {data_file} .", flush=True)
|
||||
dataset = LocalDataset(dataset, data_file)
|
||||
elif legacy_file.exists():
|
||||
print(f"Load the local dataset {legacy_file} .")
|
||||
print()
|
||||
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
|
||||
print(" See the README for details.")
|
||||
print(flush=True)
|
||||
dataset = LegacyDataset(dataset)
|
||||
else:
|
||||
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
||||
dataset = HuggingFaceDataset(dataset)
|
||||
|
||||
return dataset
|
||||
246
flux/flux/flux.py
Normal file
246
flux/flux/flux.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_clip_tokenizer,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str, t5_padding: bool = True):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.t5_padding = t5_padding
|
||||
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
self.clip = load_clip(name)
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.sampler = FluxSampler(name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
self.ae.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.clip.parameters(),
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def reload_text_encoders(self):
|
||||
self.t5 = load_t5(self.name)
|
||||
self.clip = load_clip(self.name)
|
||||
|
||||
def tokenize(self, text):
|
||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
return t5_tokens, clip_tokens
|
||||
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
||||
# Prepare the text features
|
||||
txt = self.t5(t5_tokens)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _denoising_loop(
|
||||
self,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
B = len(x_t)
|
||||
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the conditioning
|
||||
t5_tokens, clip_tokens = self.tokenize(text)
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
||||
|
||||
# Yield the conditioning for controlled evaluation by the caller
|
||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
||||
|
||||
# Yield the latent sequences from the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
return mx.clip(x + 1, 0, 2) * 0.5
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
reload_text_encoders: bool = True,
|
||||
progress: bool = True,
|
||||
):
|
||||
latents = self.generate_latents(
|
||||
text, n_images, num_steps, guidance, latent_size, seed
|
||||
)
|
||||
mx.eval(next(latents))
|
||||
|
||||
if reload_text_encoders:
|
||||
self.reload_text_encoders()
|
||||
|
||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
||||
mx.eval(x_t)
|
||||
|
||||
images = []
|
||||
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
|
||||
images.append(self.decode(x_t[i : i + 1]))
|
||||
mx.eval(images[-1])
|
||||
images = mx.concatenate(images, axis=0)
|
||||
mx.eval(images)
|
||||
|
||||
return images
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return (pred + x_0 - eps).square().mean()
|
||||
|
||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
||||
all_blocks.reverse()
|
||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
||||
for i, block in zip(range(num_blocks), all_blocks):
|
||||
loras = []
|
||||
for name, module in block.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
||||
block.update_modules(tree_unflatten(loras))
|
||||
|
||||
def fuse_lora_layers(self):
|
||||
fused_layers = []
|
||||
for name, module in self.flow.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
fused_layers.append((name, module.fuse()))
|
||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
||||
321
flux/flux/layers.py
Normal file
321
flux/flux/layers.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def _rope(pos: mx.array, dim: int, theta: float):
|
||||
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
x = pos[..., None] * omega
|
||||
cosx = mx.cos(x)
|
||||
sinx = mx.sin(x)
|
||||
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
|
||||
pe = pe.reshape(*pe.shape[:-1], 2, 2)
|
||||
|
||||
return pe
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def _ab_plus_cd(a, b, c, d):
|
||||
return a * b + c * d
|
||||
|
||||
|
||||
def _apply_rope(x, pe):
|
||||
s = x.shape
|
||||
x = x.reshape(*s[:-1], -1, 1, 2)
|
||||
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
|
||||
return x.reshape(s)
|
||||
|
||||
|
||||
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
|
||||
B, H, L, D = q.shape
|
||||
|
||||
q = _apply_rope(q, pe)
|
||||
k = _apply_rope(k, pe)
|
||||
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
|
||||
|
||||
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
|
||||
def timestep_embedding(
|
||||
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
|
||||
):
|
||||
half = dim // 2
|
||||
freqs = mx.arange(0, half, dtype=mx.float32) / half
|
||||
freqs = freqs * (-math.log(max_period))
|
||||
freqs = mx.exp(freqs)
|
||||
|
||||
x = (time_factor * t)[:, None] * freqs[None]
|
||||
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
|
||||
|
||||
return x.astype(t.dtype)
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def __call__(self, ids: mx.array):
|
||||
n_axes = ids.shape[-1]
|
||||
pe = mx.concatenate(
|
||||
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
axis=-3,
|
||||
)
|
||||
|
||||
return pe[:, None]
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.out_layer(nn.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class QKNorm(nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = nn.RMSNorm(dim)
|
||||
self.key_norm = nn.RMSNorm(dim)
|
||||
|
||||
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
|
||||
return self.query_norm(q), self.key_norm(k)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
|
||||
H = self.num_heads
|
||||
B, L, _ = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
q, k = self.norm(q, k)
|
||||
x = _attention(q, k, v, pe)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: mx.array
|
||||
scale: mx.array
|
||||
gate: mx.array
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
|
||||
x = self.lin(nn.silu(x))
|
||||
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
|
||||
|
||||
mod1 = ModulationOut(*xs[:3])
|
||||
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
|
||||
|
||||
return mod1, mod2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(
|
||||
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
||||
)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approx="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(
|
||||
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
|
||||
)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approx="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.sharding_group = None
|
||||
|
||||
def __call__(
|
||||
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
B, L, _ = img.shape
|
||||
_, S, _ = txt.shape
|
||||
H = self.num_heads
|
||||
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
|
||||
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
|
||||
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
|
||||
|
||||
# run actual attention
|
||||
q = mx.concatenate([txt_q, img_q], axis=2)
|
||||
k = mx.concatenate([txt_k, img_k], axis=2)
|
||||
v = mx.concatenate([txt_v, img_v], axis=2)
|
||||
|
||||
attn = _attention(q, k, v, pe)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# Project - cat - average - split
|
||||
txt_attn = self.txt_attn.proj(txt_attn)
|
||||
img_attn = self.img_attn.proj(img_attn)
|
||||
if self.sharding_group is not None:
|
||||
attn = mx.concatenate([txt_attn, img_attn], axis=1)
|
||||
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * img_attn
|
||||
img_mlp = self.img_mlp(
|
||||
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
||||
)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * txt_attn
|
||||
txt_mlp = self.txt_mlp(
|
||||
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
||||
)
|
||||
|
||||
if self.sharding_group is not None:
|
||||
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
|
||||
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
|
||||
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
|
||||
|
||||
# finalize the img/txt blocks
|
||||
img = img + img_mod2.gate * img_mlp
|
||||
txt = txt + txt_mod2.gate * txt_mlp
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approx="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
|
||||
B, L, _ = x.shape
|
||||
H = self.num_heads
|
||||
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
|
||||
q, k, v, mlp = mx.split(
|
||||
self.linear1(x_mod),
|
||||
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
|
||||
axis=-1,
|
||||
)
|
||||
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
|
||||
q, k = self.norm(q, k)
|
||||
|
||||
# compute attention
|
||||
y = _attention(q, k, v, pe)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
|
||||
return x + mod.gate * y
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(
|
||||
hidden_size, patch_size * patch_size * out_channels, bias=True
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, vec: mx.array):
|
||||
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
@@ -6,47 +6,40 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class DoRALinear(nn.Module):
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(
|
||||
def from_base(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
# TODO support quantized weights in DoRALinear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
raise ValueError("DoRALinear does not yet support quantization.")
|
||||
dora_lin = DoRALinear(
|
||||
lora_lin = LoRALinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
dropout=dropout,
|
||||
scale=scale,
|
||||
)
|
||||
dora_lin.linear = linear
|
||||
return dora_lin
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
|
||||
def to_linear(self, de_quantize: bool = False):
|
||||
def fuse(self):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
dtype = weight.dtype
|
||||
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||
lora_a = self.lora_a.T.astype(dtype)
|
||||
weight = weight + lora_b @ lora_a
|
||||
norm_scale = self.m / mx.linalg.norm(weight, axis=1)
|
||||
fused_linear.weight = norm_scale[:, None] * weight
|
||||
|
||||
lora_b = self.scale * self.lora_b.T
|
||||
lora_a = self.lora_a.T
|
||||
fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
@@ -55,13 +48,14 @@ class DoRALinear(nn.Module):
|
||||
output_dims: int,
|
||||
r: int = 8,
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
scale: float = 1.0,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
@@ -75,21 +69,8 @@ class DoRALinear(nn.Module):
|
||||
shape=(input_dims, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||
self.m = mx.linalg.norm(self.linear.weight, axis=1)
|
||||
|
||||
def __call__(self, x):
|
||||
# Regular LoRA (without a bias)
|
||||
y = x @ self.linear.weight.T
|
||||
y = self.linear(x)
|
||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||
out = y + (self.scale * z).astype(x.dtype)
|
||||
|
||||
# Compute the norm of the adapted weights
|
||||
adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
|
||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
||||
|
||||
# Remove the norm and scale by the learned magnitude
|
||||
out = (self.m / denom) * out
|
||||
|
||||
if "bias" in self.linear:
|
||||
out = out + self.linear.bias
|
||||
return out
|
||||
return y + (self.scale * z).astype(x.dtype)
|
||||
178
flux/flux/model.py
Normal file
178
flux/flux/model.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(
|
||||
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
|
||||
)
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(
|
||||
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
|
||||
)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
if params.guidance_embed
|
||||
else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = [
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
|
||||
self.single_blocks = [
|
||||
SingleStreamBlock(
|
||||
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for k, w in weights.items():
|
||||
if k.startswith("model.diffusion_model."):
|
||||
k = k[22:]
|
||||
if k.endswith(".scale"):
|
||||
k = k[:-6] + ".weight"
|
||||
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
|
||||
if f".{seq}." in k:
|
||||
k = k.replace(f".{seq}.", f".{seq}.layers.")
|
||||
break
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def shard(self, group: Optional[mx.distributed.Group] = None):
|
||||
group = group or mx.distributed.init()
|
||||
N = group.size()
|
||||
if N == 1:
|
||||
return
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.num_heads //= N
|
||||
block.img_attn.num_heads //= N
|
||||
block.txt_attn.num_heads //= N
|
||||
block.sharding_group = group
|
||||
block.img_attn.qkv = shard_linear(
|
||||
block.img_attn.qkv, "all-to-sharded", segments=3, group=group
|
||||
)
|
||||
block.txt_attn.qkv = shard_linear(
|
||||
block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
|
||||
)
|
||||
shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
|
||||
shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
|
||||
block.img_mlp.layers[0] = shard_linear(
|
||||
block.img_mlp.layers[0], "all-to-sharded", group=group
|
||||
)
|
||||
block.txt_mlp.layers[0] = shard_linear(
|
||||
block.txt_mlp.layers[0], "all-to-sharded", group=group
|
||||
)
|
||||
shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
|
||||
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.num_heads //= N
|
||||
block.hidden_size //= N
|
||||
block.linear1 = shard_linear(
|
||||
block.linear1,
|
||||
"all-to-sharded",
|
||||
segments=[1 / 7, 2 / 7, 3 / 7],
|
||||
group=group,
|
||||
)
|
||||
block.linear2 = shard_linear(
|
||||
block.linear2, "sharded-to-all", segments=[1 / 5], group=group
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
img: mx.array,
|
||||
img_ids: mx.array,
|
||||
txt: mx.array,
|
||||
txt_ids: mx.array,
|
||||
timesteps: mx.array,
|
||||
y: mx.array,
|
||||
guidance: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError(
|
||||
"Didn't get guidance strength for guidance distilled model."
|
||||
)
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = mx.concatenate([txt_ids, img_ids], axis=1)
|
||||
pe = self.pe_embedder(ids).astype(img.dtype)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
img = mx.concatenate([txt, img], axis=1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
return img
|
||||
57
flux/flux/sampler.py
Normal file
57
flux/flux/sampler.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class FluxSampler:
|
||||
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
|
||||
self._base_shift = base_shift
|
||||
self._max_shift = max_shift
|
||||
self._schnell = "schnell" in name
|
||||
|
||||
def _time_shift(self, x, t):
|
||||
x1, x2 = 256, 4096
|
||||
t1, t2 = self._base_shift, self._max_shift
|
||||
exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
|
||||
t = exp_mu / (exp_mu + (1 / t - 1))
|
||||
return t
|
||||
|
||||
@lru_cache
|
||||
def timesteps(
|
||||
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
|
||||
):
|
||||
t = mx.linspace(start, stop, num_steps + 1)
|
||||
|
||||
if not self._schnell:
|
||||
t = self._time_shift(image_sequence_length, t)
|
||||
|
||||
return t.tolist()
|
||||
|
||||
def random_timesteps(self, B, L, dtype=mx.float32, key=None):
|
||||
if self._schnell:
|
||||
# TODO: Should we upweigh 1 and 0.75?
|
||||
t = mx.random.randint(1, 5, shape=(B,), key=key)
|
||||
t = t.astype(dtype) / 4
|
||||
else:
|
||||
t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
|
||||
t = self._time_shift(L, t)
|
||||
|
||||
return t
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
return mx.random.normal(shape, dtype=dtype, key=key)
|
||||
|
||||
def add_noise(self, x, t, noise=None, key=None):
|
||||
noise = (
|
||||
noise
|
||||
if noise is not None
|
||||
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
|
||||
)
|
||||
t = t.reshape([-1] + [1] * (x.ndim - 1))
|
||||
return x * (1 - t) + t * noise
|
||||
|
||||
def step(self, pred, x_t, t, t_prev):
|
||||
return x_t + (t_prev - t) * pred
|
||||
244
flux/flux/t5.py
Normal file
244
flux/flux/t5.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
_SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
_ENCODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class T5Config:
|
||||
vocab_size: int
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
relative_attention_num_buckets: int
|
||||
d_kv: int
|
||||
d_model: int
|
||||
feed_forward_proj: str
|
||||
tie_word_embeddings: bool
|
||||
|
||||
d_ff: Optional[int] = None
|
||||
num_decoder_layers: Optional[int] = None
|
||||
relative_attention_max_distance: int = 128
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config):
|
||||
return cls(
|
||||
vocab_size=config["vocab_size"],
|
||||
num_layers=config["num_layers"],
|
||||
num_heads=config["num_heads"],
|
||||
relative_attention_num_buckets=config["relative_attention_num_buckets"],
|
||||
d_kv=config["d_kv"],
|
||||
d_model=config["d_model"],
|
||||
feed_forward_proj=config["feed_forward_proj"],
|
||||
tie_word_embeddings=config["tie_word_embeddings"],
|
||||
d_ff=config.get("d_ff", 4 * config["d_model"]),
|
||||
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
|
||||
relative_attention_max_distance=config.get(
|
||||
"relative_attention_max_distance", 128
|
||||
),
|
||||
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
|
||||
)
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, config: T5Config, bidirectional: bool):
|
||||
self.bidirectional = bidirectional
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
|
||||
num_buckets = num_buckets // 2 if bidirectional else num_buckets
|
||||
max_exact = num_buckets // 2
|
||||
|
||||
abspos = rpos.abs()
|
||||
is_small = abspos < max_exact
|
||||
|
||||
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
|
||||
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
|
||||
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
|
||||
|
||||
buckets = mx.where(is_small, abspos, buckets_large)
|
||||
if bidirectional:
|
||||
buckets = buckets + (rpos > 0) * num_buckets
|
||||
else:
|
||||
buckets = buckets * (rpos < 0)
|
||||
|
||||
return buckets
|
||||
|
||||
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = mx.arange(offset, query_length)[:, None]
|
||||
memory_position = mx.arange(key_length)[None, :]
|
||||
|
||||
# shape (query_length, key_length)
|
||||
relative_position = memory_position - context_position
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position,
|
||||
bidirectional=self.bidirectional,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
|
||||
# shape (query_length, key_length, num_heads)
|
||||
values = self.embeddings(relative_position_bucket)
|
||||
|
||||
# shape (num_heads, query_length, key_length)
|
||||
return values.transpose(2, 0, 1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
inner_dim = config.d_kv * config.num_heads
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
queries: mx.array,
|
||||
keys: mx.array,
|
||||
values: mx.array,
|
||||
mask: Optional[mx.array],
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
keys = mx.concatenate([key_cache, keys], axis=3)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
|
||||
values_hat = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
|
||||
)
|
||||
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.gated = config.feed_forward_proj.startswith("gated")
|
||||
if self.gated:
|
||||
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
else:
|
||||
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||
activation = config.feed_forward_proj.removeprefix("gated-")
|
||||
if activation == "relu":
|
||||
self.act = nn.relu
|
||||
elif activation == "gelu":
|
||||
self.act = nn.gelu
|
||||
elif activation == "silu":
|
||||
self.act = nn.silu
|
||||
else:
|
||||
raise ValueError(f"Unknown activation: {activation}")
|
||||
|
||||
def __call__(self, x):
|
||||
if self.gated:
|
||||
hidden_act = self.act(self.wi_0(x))
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_act * hidden_linear
|
||||
else:
|
||||
x = self.act(self.wi(x))
|
||||
return self.wo(x)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
|
||||
def __call__(self, x, mask):
|
||||
y = self.ln1(x)
|
||||
y, _ = self.attention(y, y, y, mask=mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y = self.dense(y)
|
||||
return x + y
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||
pos_bias = pos_bias.astype(x.dtype)
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=pos_bias)
|
||||
return self.ln(x)
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for k, w in weights.items():
|
||||
for old, new in _SHARED_REPLACEMENT_PATTERNS:
|
||||
k = k.replace(old, new)
|
||||
if k.startswith("encoder."):
|
||||
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
|
||||
k = k.replace(old, new)
|
||||
new_weights[k] = w
|
||||
return new_weights
|
||||
|
||||
def __call__(self, inputs: mx.array):
|
||||
return self.encoder(self.wte(inputs))
|
||||
185
flux/flux/tokenizers.py
Normal file
185
flux/flux/tokenizers.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import regex
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
class CLIPTokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab, max_length=77):
|
||||
self.max_length = max_length
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return "<|startoftext|>"
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self.vocab[self.bos]
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return "<|endoftext|>"
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self.vocab[self.eos]
|
||||
|
||||
def bpe(self, text):
|
||||
if text in self._cache:
|
||||
return self._cache[text]
|
||||
|
||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
if not unique_bigrams:
|
||||
return unigrams
|
||||
|
||||
# In every iteration try to merge the two most likely bigrams. If none
|
||||
# was merged we are done.
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(
|
||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||
)
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
new_unigrams = []
|
||||
skip = False
|
||||
for a, b in zip(unigrams, unigrams[1:]):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
if (a, b) == bigram:
|
||||
new_unigrams.append(a + b)
|
||||
skip = True
|
||||
|
||||
else:
|
||||
new_unigrams.append(a)
|
||||
|
||||
if not skip:
|
||||
new_unigrams.append(b)
|
||||
|
||||
unigrams = new_unigrams
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
self._cache[text] = unigrams
|
||||
|
||||
return unigrams
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||
# a much more thorough job here but this should suffice for 95% of
|
||||
# cases.
|
||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||
tokens = regex.findall(self.pat, clean_text)
|
||||
|
||||
# Split the tokens according to the byte-pair merge file
|
||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||
|
||||
# Map to token ids and return
|
||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||
if prepend_bos:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
if len(tokens) > self.max_length:
|
||||
tokens = tokens[: self.max_length]
|
||||
if append_eos:
|
||||
tokens[-1] = self.eos_token
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text):
|
||||
if not isinstance(text, list):
|
||||
return self.encode([text])
|
||||
|
||||
tokens = self.tokenize(text)
|
||||
length = max(len(t) for t in tokens)
|
||||
for t in tokens:
|
||||
t.extend([self.eos_token] * (length - len(t)))
|
||||
|
||||
return mx.array(tokens)
|
||||
|
||||
|
||||
class T5Tokenizer:
|
||||
def __init__(self, model_file, max_length=512):
|
||||
self._tokenizer = SentencePieceProcessor(model_file)
|
||||
self.max_length = max_length
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.pad_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return self._tokenizer.pad_id()
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.bos_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self._tokenizer.bos_id()
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.eos_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self._tokenizer.eos_id()
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
|
||||
|
||||
tokens = self._tokenizer.encode(text)
|
||||
|
||||
if prepend_bos and self.bos_token >= 0:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos and self.eos_token >= 0:
|
||||
tokens.append(self.eos_token)
|
||||
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
|
||||
tokens += [self.pad_token] * (self.max_length - len(tokens))
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, text, pad=True):
|
||||
if not isinstance(text, list):
|
||||
return self.encode([text], pad=pad)
|
||||
|
||||
pad_token = self.pad_token if self.pad_token >= 0 else 0
|
||||
tokens = self.tokenize(text, pad=pad)
|
||||
length = max(len(t) for t in tokens)
|
||||
for t in tokens:
|
||||
t.extend([pad_token] * (length - len(t)))
|
||||
|
||||
return mx.array(tokens)
|
||||
98
flux/flux/trainer.py
Normal file
98
flux/flux/trainer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from tqdm import tqdm
|
||||
|
||||
from .datasets import Dataset
|
||||
from .flux import FluxPipeline
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
||||
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
|
||||
self.flux = flux
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * b) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * c,
|
||||
pan[1] * d,
|
||||
crop_size[0] + pan[0] * c,
|
||||
crop_size[1] + pan[1] * d,
|
||||
)
|
||||
)
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
|
||||
for i in range(num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def _encode_prompt(self, prompt):
|
||||
t5_tok, clip_tok = self.flux.tokenize([prompt])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def encode_dataset(self):
|
||||
"""Encode the images & prompt in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
|
||||
self._encode_image(image, self.args.num_augmentations)
|
||||
self._encode_prompt(prompt)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i : i + batch_size]
|
||||
c_i = c_indices[i : i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
||||
230
flux/flux/utils.py
Normal file
230
flux/flux/utils.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from .clip import CLIPTextModel, CLIPTextModelConfig
|
||||
from .model import Flux, FluxParams
|
||||
from .t5 import T5Config, T5Encoder
|
||||
from .tokenizers import CLIPTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
ckpt_path: Optional[str]
|
||||
ae_path: Optional[str]
|
||||
repo_id: Optional[str]
|
||||
repo_flow: Optional[str]
|
||||
repo_ae: Optional[str]
|
||||
|
||||
|
||||
configs = {
|
||||
"flux-dev": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-dev",
|
||||
repo_flow="flux1-dev.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_DEV"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-schnell": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-schnell",
|
||||
repo_flow="flux1-schnell.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def load_flow_model(name: str, hf_download: bool = True):
|
||||
# Get the safetensors file to load
|
||||
ckpt_path = configs[name].ckpt_path
|
||||
|
||||
# Download if needed
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_flow is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
||||
|
||||
# Make the model
|
||||
model = Flux(configs[name].params)
|
||||
|
||||
# Load the checkpoint if needed
|
||||
if ckpt_path is not None:
|
||||
weights = mx.load(ckpt_path)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_ae(name: str, hf_download: bool = True):
|
||||
# Get the safetensors file to load
|
||||
ckpt_path = configs[name].ae_path
|
||||
|
||||
# Download if needed
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_ae is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
|
||||
|
||||
# Make the autoencoder
|
||||
ae = AutoEncoder(configs[name].ae_params)
|
||||
|
||||
# Load the checkpoint if needed
|
||||
if ckpt_path is not None:
|
||||
weights = mx.load(ckpt_path)
|
||||
weights = ae.sanitize(weights)
|
||||
ae.load_weights(list(weights.items()))
|
||||
|
||||
return ae
|
||||
|
||||
|
||||
def load_clip(name: str):
|
||||
# Load the config
|
||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
|
||||
with open(config_path) as f:
|
||||
config = CLIPTextModelConfig.from_dict(json.load(f))
|
||||
|
||||
# Make the clip text encoder
|
||||
clip = CLIPTextModel(config)
|
||||
|
||||
# Load the weights
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
|
||||
weights = mx.load(ckpt_path)
|
||||
weights = clip.sanitize(weights)
|
||||
clip.load_weights(list(weights.items()))
|
||||
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5(name: str):
|
||||
# Load the config
|
||||
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
|
||||
with open(config_path) as f:
|
||||
config = T5Config.from_dict(json.load(f))
|
||||
|
||||
# Make the T5 model
|
||||
t5 = T5Encoder(config)
|
||||
|
||||
# Load the weights
|
||||
model_index = hf_hub_download(
|
||||
configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
|
||||
)
|
||||
weight_files = set()
|
||||
with open(model_index) as f:
|
||||
for _, w in json.load(f)["weight_map"].items():
|
||||
weight_files.add(w)
|
||||
weights = {}
|
||||
for w in weight_files:
|
||||
w = f"text_encoder_2/{w}"
|
||||
w = hf_hub_download(configs[name].repo_id, w)
|
||||
weights.update(mx.load(w))
|
||||
weights = t5.sanitize(weights)
|
||||
t5.load_weights(list(weights.items()))
|
||||
|
||||
return t5
|
||||
|
||||
|
||||
def load_clip_tokenizer(name: str):
|
||||
vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
|
||||
with open(merges_file, encoding="utf-8") as f:
|
||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
|
||||
|
||||
|
||||
def load_t5_tokenizer(name: str, pad: bool = True):
|
||||
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
||||
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
||||
|
||||
|
||||
def save_config(
|
||||
config: dict,
|
||||
config_path: Union[str, Path],
|
||||
) -> None:
|
||||
"""Save the model configuration to the ``config_path``.
|
||||
|
||||
The final configuration will be sorted before saving for better readability.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
config_path (Union[str, Path]): Model configuration file path.
|
||||
"""
|
||||
# Sort the config for better readability
|
||||
config = dict(sorted(config.items()))
|
||||
|
||||
# Write the config to the provided file
|
||||
with open(config_path, "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
109
flux/generate_interactive.py
Normal file
109
flux/generate_interactive.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
def print_zero(group, *args, **kwargs):
|
||||
if group.rank() == 0:
|
||||
flush = kwargs.pop("flush", True)
|
||||
print(*args, **kwargs, flush=flush)
|
||||
|
||||
|
||||
def quantization_predicate(name, m):
|
||||
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
|
||||
|
||||
|
||||
def to_latent_size(image_size):
|
||||
h, w = image_size
|
||||
h = ((h + 15) // 16) * 16
|
||||
w = ((w + 15) // 16) * 16
|
||||
|
||||
if (h, w) != image_size:
|
||||
print(
|
||||
"Warning: The image dimensions need to be divisible by 16px. "
|
||||
f"Changing size to {h}x{w}."
|
||||
)
|
||||
|
||||
return (h // 8, w // 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate images from a textual prompt using FLUX"
|
||||
)
|
||||
parser.add_argument("--quantize", "-q", action="store_true")
|
||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
args = parser.parse_args()
|
||||
|
||||
flux = FluxPipeline("flux-" + args.model, t5_padding=True)
|
||||
|
||||
if args.quantize:
|
||||
nn.quantize(flux.flow, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
group = mx.distributed.init()
|
||||
if group.size() > 1:
|
||||
flux.flow.shard(group)
|
||||
|
||||
print_zero(group, "Loading models")
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
def print_help():
|
||||
print_zero(group, "The command list:")
|
||||
print_zero(group, "- 'q' to exit")
|
||||
print_zero(group, "- 's HxW' to change the size of the image")
|
||||
print_zero(group, "- 'n S' to change the number of steps")
|
||||
print_zero(group, "- 'h' to print this help")
|
||||
|
||||
print_zero(group, "FLUX interactive session")
|
||||
print_help()
|
||||
seed = 0
|
||||
size = (512, 512)
|
||||
latent_size = to_latent_size(size)
|
||||
steps = 50 if args.model == "dev" else 4
|
||||
while True:
|
||||
prompt = input(">> " if group.rank() == 0 else "")
|
||||
if prompt == "q":
|
||||
break
|
||||
if prompt == "h":
|
||||
print_help()
|
||||
continue
|
||||
if prompt.startswith("s "):
|
||||
size = tuple([int(xi) for xi in prompt[2:].split("x")])
|
||||
print_zero(group, "Setting the size to", size)
|
||||
latent_size = to_latent_size(size)
|
||||
continue
|
||||
if prompt.startswith("n "):
|
||||
steps = int(prompt[2:])
|
||||
print_zero(group, "Setting the steps to", steps)
|
||||
continue
|
||||
|
||||
seed += 1
|
||||
latents = flux.generate_latents(
|
||||
prompt,
|
||||
n_images=1,
|
||||
num_steps=steps,
|
||||
latent_size=latent_size,
|
||||
guidance=4.0,
|
||||
seed=seed,
|
||||
)
|
||||
print_zero(group, "Processing prompt")
|
||||
mx.eval(next(latents))
|
||||
print_zero(group, "Generating latents")
|
||||
for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
|
||||
mx.eval(xt)
|
||||
print_zero(group, "Generating image")
|
||||
xt = flux.decode(xt, latent_size)
|
||||
xt = (xt * 255).astype(mx.uint8)
|
||||
mx.eval(xt)
|
||||
im = Image.fromarray(np.array(xt[0]))
|
||||
im.save(args.output)
|
||||
print_zero(group, "Saved at", args.output, end="\n\n")
|
||||
7
flux/requirements.txt
Normal file
7
flux/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
mlx>=0.18.1
|
||||
huggingface-hub
|
||||
regex
|
||||
numpy
|
||||
tqdm
|
||||
Pillow
|
||||
sentencepiece
|
||||
BIN
flux/static/dog-r4-g8-1200-512x1024.png
Normal file
BIN
flux/static/dog-r4-g8-1200-512x1024.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 754 KiB |
BIN
flux/static/dog-r4-g8-1200.png
Normal file
BIN
flux/static/dog-r4-g8-1200.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 423 KiB |
BIN
flux/static/dog6.png
Normal file
BIN
flux/static/dog6.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 434 KiB |
BIN
flux/static/generated-mlx.png
Normal file
BIN
flux/static/generated-mlx.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 153 KiB |
175
flux/txt2image.py
Normal file
175
flux/txt2image.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
def to_latent_size(image_size):
|
||||
h, w = image_size
|
||||
h = ((h + 15) // 16) * 16
|
||||
w = ((w + 15) // 16) * 16
|
||||
|
||||
if (h, w) != image_size:
|
||||
print(
|
||||
"Warning: The image dimensions need to be divisible by 16px. "
|
||||
f"Changing size to {h}x{w}."
|
||||
)
|
||||
|
||||
return (h // 8, w // 8)
|
||||
|
||||
|
||||
def quantization_predicate(name, m):
|
||||
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
|
||||
|
||||
|
||||
def load_adapter(flux, adapter_file, fuse=False):
|
||||
weights, lora_config = mx.load(adapter_file, return_metadata=True)
|
||||
rank = int(lora_config["lora_rank"])
|
||||
num_blocks = int(lora_config["lora_blocks"])
|
||||
flux.linear_to_lora_layers(rank, num_blocks)
|
||||
flux.flow.load_weights(list(weights.items()), strict=False)
|
||||
if fuse:
|
||||
flux.fuse_lora_layers()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate images from a textual prompt using FLUX"
|
||||
)
|
||||
parser.add_argument("prompt")
|
||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
||||
parser.add_argument("--n-images", type=int, default=4)
|
||||
parser.add_argument(
|
||||
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
|
||||
)
|
||||
parser.add_argument("--steps", type=int)
|
||||
parser.add_argument("--guidance", type=float, default=4.0)
|
||||
parser.add_argument("--n-rows", type=int, default=1)
|
||||
parser.add_argument("--decoding-batch-size", type=int, default=1)
|
||||
parser.add_argument("--quantize", "-q", action="store_true")
|
||||
parser.add_argument("--preload-models", action="store_true")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
parser.add_argument("--save-raw", action="store_true")
|
||||
parser.add_argument("--seed", type=int)
|
||||
parser.add_argument("--verbose", "-v", action="store_true")
|
||||
parser.add_argument("--adapter")
|
||||
parser.add_argument("--fuse-adapter", action="store_true")
|
||||
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
|
||||
parser.add_argument("--force-shard", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the models
|
||||
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
|
||||
args.steps = args.steps or (50 if args.model == "dev" else 2)
|
||||
|
||||
if args.adapter:
|
||||
load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
|
||||
|
||||
if args.quantize:
|
||||
nn.quantize(flux.flow, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
# Figure out what kind of distributed generation we should do
|
||||
group = mx.distributed.init()
|
||||
n_images = args.n_images
|
||||
should_gather = False
|
||||
if group.size() > 1:
|
||||
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
|
||||
flux.flow.shard(group)
|
||||
else:
|
||||
n_images //= group.size()
|
||||
should_gather = True
|
||||
|
||||
# If we are sharding we should have the same seed and if we are doing
|
||||
# data parallel generation we should have different seeds
|
||||
if args.seed is None:
|
||||
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
|
||||
if should_gather:
|
||||
args.seed = args.seed + group.rank()
|
||||
|
||||
if args.preload_models:
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
# Make the generator
|
||||
latent_size = to_latent_size(args.image_size)
|
||||
latents = flux.generate_latents(
|
||||
args.prompt,
|
||||
n_images=n_images,
|
||||
num_steps=args.steps,
|
||||
latent_size=latent_size,
|
||||
guidance=args.guidance,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# First we get and eval the conditioning
|
||||
conditioning = next(latents)
|
||||
mx.eval(conditioning)
|
||||
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
|
||||
mx.reset_peak_memory()
|
||||
|
||||
# The following is not necessary but it may help in memory constrained
|
||||
# systems by reusing the memory kept by the text encoders.
|
||||
del flux.t5
|
||||
del flux.clip
|
||||
|
||||
# Actual denoising loop
|
||||
for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
|
||||
mx.eval(x_t)
|
||||
|
||||
# The following is not necessary but it may help in memory constrained
|
||||
# systems by reusing the memory kept by the flow transformer.
|
||||
del flux.flow
|
||||
peak_mem_generation = mx.get_peak_memory() / 1024**3
|
||||
mx.reset_peak_memory()
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
|
||||
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
||||
mx.eval(decoded[-1])
|
||||
peak_mem_decoding = mx.get_peak_memory() / 1024**3
|
||||
peak_mem_overall = max(
|
||||
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
|
||||
)
|
||||
|
||||
# Gather them if each node has different images
|
||||
decoded = mx.concatenate(decoded, axis=0)
|
||||
if should_gather:
|
||||
decoded = mx.distributed.all_gather(decoded)
|
||||
mx.eval(decoded)
|
||||
|
||||
if args.save_raw:
|
||||
*name, suffix = args.output.split(".")
|
||||
name = ".".join(name)
|
||||
x = decoded
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
for i in range(len(x)):
|
||||
im = Image.fromarray(np.array(x[i]))
|
||||
im.save(".".join([name, str(i), suffix]))
|
||||
else:
|
||||
# Arrange them on a grid
|
||||
x = decoded
|
||||
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(np.array(x))
|
||||
im.save(args.output)
|
||||
|
||||
# Report the peak memory used during generation
|
||||
if args.verbose and group.rank() == 0:
|
||||
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
|
||||
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
||||
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
||||
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
||||
@@ -79,10 +79,10 @@ def load_image(image_source):
|
||||
def prepare_inputs(processor, image, prompt):
|
||||
if isinstance(image, str):
|
||||
image = load_image(image)
|
||||
inputs = processor(prompt, image, return_tensors="np")
|
||||
inputs = processor(image, prompt, return_tensors="np")
|
||||
pixel_values = mx.array(inputs["pixel_values"])
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
return input_ids, pixel_values
|
||||
return pixel_values, input_ids
|
||||
|
||||
|
||||
def load_model(model_path, tokenizer_config={}):
|
||||
@@ -126,8 +126,7 @@ def main():
|
||||
processor, model = load_model(args.model, tokenizer_config)
|
||||
|
||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||
|
||||
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
|
||||
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
|
||||
|
||||
print(prompt)
|
||||
generated_text = generate_text(
|
||||
|
||||
@@ -68,11 +68,10 @@ class LlavaModel(nn.Module):
|
||||
input_ids: Optional[mx.array] = None,
|
||||
pixel_values: Optional[mx.array] = None,
|
||||
):
|
||||
if pixel_values is None:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
# Get the input embeddings from the language model
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
if pixel_values is None:
|
||||
return inputs_embeds
|
||||
|
||||
# Get the ouptut hidden states from the vision model
|
||||
*_, hidden_states = self.vision_tower(
|
||||
@@ -105,31 +104,21 @@ class LlavaModel(nn.Module):
|
||||
self, image_features, inputs_embeds, input_ids
|
||||
):
|
||||
image_token_index = self.config.image_token_index
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, num_image_patches, embed_dim = image_features.shape
|
||||
|
||||
# Positions of <image> tokens in input_ids, assuming batch size is 1
|
||||
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
|
||||
image_positions = mx.array(
|
||||
np.where(input_ids[0] == image_token_index)[0], mx.uint32
|
||||
)
|
||||
|
||||
if len(image_positions) != num_images:
|
||||
if len(image_positions) != num_image_patches:
|
||||
raise ValueError(
|
||||
f"The number of image tokens ({len(image_positions)}) does not "
|
||||
f" match the number of image inputs ({num_images})."
|
||||
f" match the number of image patches ({num_image_patches})."
|
||||
)
|
||||
|
||||
text_segments = []
|
||||
start_idx = 0
|
||||
|
||||
for position in image_positions:
|
||||
text_segments.append(inputs_embeds[:, start_idx:position])
|
||||
start_idx = position + 1
|
||||
|
||||
image_embeddings = mx.split(image_features, image_features.shape[0])
|
||||
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
|
||||
final_embeddings += [inputs_embeds[:, start_idx:]]
|
||||
|
||||
# Create a final embedding of shape
|
||||
# (1, num_image_patches*num_images + sequence_len, embed_dim)
|
||||
return mx.concatenate(final_embeddings, axis=1)
|
||||
inputs_embeds[0, image_positions] = image_features
|
||||
return inputs_embeds
|
||||
|
||||
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
|
||||
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
# Contributing to MLX LM
|
||||
|
||||
Below are some tips to port LLMs available on Hugging Face to MLX.
|
||||
|
||||
Before starting checkout the [general contribution
|
||||
guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md).
|
||||
|
||||
Next, from this directory, do an editable install:
|
||||
|
||||
```shell
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then check if the model has weights in the
|
||||
[safetensors](https://huggingface.co/docs/safetensors/index) format. If not
|
||||
[follow instructions](https://huggingface.co/spaces/safetensors/convert) to
|
||||
convert it.
|
||||
|
||||
After that, add the model file to the
|
||||
[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models)
|
||||
directory. You can see other examples there. We recommend starting from a model
|
||||
that is similar to the model you are porting.
|
||||
|
||||
Make sure the name of the new model file is the same as the `model_type` in the
|
||||
`config.json`, for example
|
||||
[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17).
|
||||
|
||||
To determine the model layer names, we suggest either:
|
||||
|
||||
- Refer to the Transformers implementation if you are familiar with the
|
||||
codebase.
|
||||
- Load the model weights and check the weight names which will tell you about
|
||||
the model structure.
|
||||
- Look at the names of the weights by inspecting `model.safetensors.index.json`
|
||||
in the Hugging Face repo.
|
||||
|
||||
To add LoRA support edit
|
||||
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
|
||||
|
||||
Finally, add a test for the new modle type to the [model
|
||||
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
|
||||
|
||||
From the `llms/` directory, you can run the tests with:
|
||||
|
||||
```shell
|
||||
python -m unittest discover tests/
|
||||
```
|
||||
@@ -1,2 +0,0 @@
|
||||
include mlx_lm/requirements.txt
|
||||
recursive-include mlx_lm/ *.py
|
||||
173
llms/README.md
173
llms/README.md
@@ -1,171 +1,6 @@
|
||||
## Generate Text with LLMs and MLX
|
||||
# MOVE NOTICE
|
||||
|
||||
The easiest way to get started is to install the `mlx-lm` package:
|
||||
The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).
|
||||
|
||||
**With `pip`**:
|
||||
|
||||
```sh
|
||||
pip install mlx-lm
|
||||
```
|
||||
|
||||
**With `conda`**:
|
||||
|
||||
```sh
|
||||
conda install -c conda-forge mlx-lm
|
||||
```
|
||||
|
||||
The `mlx-lm` package also has:
|
||||
|
||||
- [LoRA and QLoRA fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
|
||||
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
||||
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
||||
|
||||
### Python API
|
||||
|
||||
You can use `mlx-lm` as a module:
|
||||
|
||||
```python
|
||||
from mlx_lm import load, generate
|
||||
|
||||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||
|
||||
response = generate(model, tokenizer, prompt="hello", verbose=True)
|
||||
```
|
||||
|
||||
To see a description of all the arguments you can do:
|
||||
|
||||
```
|
||||
>>> help(generate)
|
||||
```
|
||||
|
||||
Check out the [generation example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py) to see how to use the API in more detail.
|
||||
|
||||
The `mlx-lm` package also comes with functionality to quantize and optionally
|
||||
upload models to the Hugging Face Hub.
|
||||
|
||||
You can convert models in the Python API with:
|
||||
|
||||
```python
|
||||
from mlx_lm import convert
|
||||
|
||||
repo = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
|
||||
|
||||
convert(repo, quantize=True, upload_repo=upload_repo)
|
||||
```
|
||||
|
||||
This will generate a 4-bit quantized Mistral 7B and upload it to the repo
|
||||
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
|
||||
converted model in the path `mlx_model` by default.
|
||||
|
||||
To see a description of all the arguments you can do:
|
||||
|
||||
```
|
||||
>>> help(convert)
|
||||
```
|
||||
|
||||
#### Streaming
|
||||
|
||||
For streaming generation, use the `stream_generate` function. This returns a
|
||||
generator object which streams the output text. For example,
|
||||
|
||||
```python
|
||||
from mlx_lm import load, stream_generate
|
||||
|
||||
repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||
model, tokenizer = load(repo)
|
||||
|
||||
prompt = "Write a story about Einstein"
|
||||
|
||||
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||
print(t, end="", flush=True)
|
||||
print()
|
||||
```
|
||||
|
||||
### Command Line
|
||||
|
||||
You can also use `mlx-lm` from the command line with:
|
||||
|
||||
```
|
||||
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
|
||||
```
|
||||
|
||||
This will download a Mistral 7B model from the Hugging Face Hub and generate
|
||||
text using the given prompt.
|
||||
|
||||
For a full list of options run:
|
||||
|
||||
```
|
||||
mlx_lm.generate --help
|
||||
```
|
||||
|
||||
To quantize a model from the command line run:
|
||||
|
||||
```
|
||||
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
|
||||
```
|
||||
|
||||
For more options run:
|
||||
|
||||
```
|
||||
mlx_lm.convert --help
|
||||
```
|
||||
|
||||
You can upload new models to Hugging Face by specifying `--upload-repo` to
|
||||
`convert`. For example, to upload a quantized Mistral-7B model to the
|
||||
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
|
||||
|
||||
```
|
||||
mlx_lm.convert \
|
||||
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
-q \
|
||||
--upload-repo mlx-community/my-4bit-mistral
|
||||
```
|
||||
|
||||
### Supported Models
|
||||
|
||||
The example supports Hugging Face format Mistral, Llama, and Phi-2 style
|
||||
models. If the model you want to run is not supported, file an
|
||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||
submit a pull request.
|
||||
|
||||
Here are a few examples of Hugging Face models that work with this example:
|
||||
|
||||
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
|
||||
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
|
||||
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
|
||||
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
|
||||
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
|
||||
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
|
||||
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||
|
||||
Most
|
||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
|
||||
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
|
||||
and
|
||||
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
|
||||
style models should work out of the box.
|
||||
|
||||
For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
|
||||
enable the `trust_remote_code` option. You can do this by passing
|
||||
`--trust-remote-code` in the command line. If you don't specify the flag
|
||||
explicitly, you will be prompted to trust remote code in the terminal when
|
||||
running the model.
|
||||
|
||||
For `Qwen` models you must also specify the `eos_token`. You can do this by
|
||||
passing `--eos-token "<|endoftext|>"` in the command
|
||||
line.
|
||||
|
||||
These options can also be set in the Python API. For example:
|
||||
|
||||
```python
|
||||
model, tokenizer = load(
|
||||
"qwen/Qwen-7B",
|
||||
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
|
||||
)
|
||||
```
|
||||
The package has been removed from the MLX Examples repo. Send new contributions
|
||||
and issues to the MLX LM repo.
|
||||
|
||||
@@ -40,7 +40,7 @@ def generate(
|
||||
if len(tokens) == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt.size / prompt_time
|
||||
prompt_tps = len(prompt) / prompt_time
|
||||
gen_tps = (len(tokens) - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||
|
||||
@@ -19,10 +19,10 @@ class ModelArgs:
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
context_length: int
|
||||
num_key_value_heads: int = None
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
model_type: Optional[str] = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -54,7 +54,7 @@ class Attention(nn.Module):
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
|
||||
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
|
||||
@@ -66,7 +66,7 @@ class Attention(nn.Module):
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
1 / float(args.rope_scaling["factor"])
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
@@ -254,7 +254,7 @@ def translate_weight_names(name):
|
||||
return name
|
||||
|
||||
|
||||
def load(gguf_file: str, repo: str = None):
|
||||
def load(gguf_file: str, repo: Optional[str] = None):
|
||||
# If the gguf_file exists, try to load model from it.
|
||||
# Otherwise try to download and cache from the HF repo
|
||||
if not Path(gguf_file).exists():
|
||||
|
||||
@@ -7,6 +7,7 @@ import glob
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -149,7 +150,8 @@ def quantize(weights, config, args):
|
||||
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
|
||||
max_file_size_bytes = max_file_size_gibibyte << 30
|
||||
shards = []
|
||||
shard, shard_size = {}, 0
|
||||
shard: Dict[str, mx.array] = {}
|
||||
shard_size = 0
|
||||
for k, v in weights.items():
|
||||
if shard_size + v.nbytes > max_file_size_bytes:
|
||||
shards.append(shard)
|
||||
|
||||
@@ -23,7 +23,7 @@ class ModelArgs:
|
||||
n_kv_heads: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
moe: dict = None
|
||||
moe: dict
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -91,7 +91,6 @@ class FeedForward(nn.Module):
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.num_experts = args.moe["num_experts"]
|
||||
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
|
||||
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
|
||||
@@ -115,7 +114,6 @@ class MOEFeedForward(nn.Module):
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt[None, :])
|
||||
y = mx.concatenate(y)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
# Fine-Tuning with LoRA or QLoRA
|
||||
|
||||
You can use use the `mlx-lm` package to fine-tune an LLM with low rank
|
||||
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
|
||||
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
|
||||
|
||||
- Mistral
|
||||
- Llama
|
||||
- Phi2
|
||||
- Mixtral
|
||||
- Qwen2
|
||||
- Gemma
|
||||
- OLMo
|
||||
- MiniCPM
|
||||
- InternLM2
|
||||
|
||||
## Contents
|
||||
|
||||
- [Run](#Run)
|
||||
- [Fine-tune](#Fine-tune)
|
||||
- [Evaluate](#Evaluate)
|
||||
- [Generate](#Generate)
|
||||
- [Fuse](#Fuse)
|
||||
- [Data](#Data)
|
||||
- [Memory Issues](#Memory-Issues)
|
||||
|
||||
## Run
|
||||
|
||||
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora --help
|
||||
```
|
||||
|
||||
Note, in the following the `--model` argument can be any compatible Hugging
|
||||
Face repo or a local path to a converted model.
|
||||
|
||||
You can also specify a YAML config with `-c`/`--config`. For more on the format see the
|
||||
[example YAML](examples/lora_config.yaml). For example:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
If command-line flags are also used, they will override the corresponding
|
||||
values in the config.
|
||||
|
||||
### Fine-tune
|
||||
|
||||
To fine-tune a model use:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora \
|
||||
--model <path_to_model> \
|
||||
--train \
|
||||
--data <path_to_data> \
|
||||
--iters 600
|
||||
```
|
||||
|
||||
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
|
||||
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
|
||||
details on the data format see the section on [Data](#Data).
|
||||
|
||||
For example, to fine-tune a Mistral 7B you can use `--model
|
||||
mistralai/Mistral-7B-v0.1`.
|
||||
|
||||
If `--model` points to a quantized model, then the training will use QLoRA,
|
||||
otherwise it will use regular LoRA.
|
||||
|
||||
By default, the adapter config and weights are saved in `adapters/`. You can
|
||||
specify the output location with `--adapter-path`.
|
||||
|
||||
You can resume fine-tuning with an existing adapter with
|
||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||
|
||||
### Evaluate
|
||||
|
||||
To compute test set perplexity use:
|
||||
|
||||
```shell
|
||||
mlx_lm.lora \
|
||||
--model <path_to_model> \
|
||||
--adapter-path <path_to_adapters> \
|
||||
--data <path_to_data> \
|
||||
--test
|
||||
```
|
||||
|
||||
### Generate
|
||||
|
||||
For generation use `mlx_lm.generate`:
|
||||
|
||||
```shell
|
||||
mlx_lm.generate \
|
||||
--model <path_to_model> \
|
||||
--adapter-path <path_to_adapters> \
|
||||
--prompt "<your_model_prompt>"
|
||||
```
|
||||
|
||||
## Fuse
|
||||
|
||||
You can generate a model fused with the low-rank adapters using the
|
||||
`mlx_lm.fuse` command. This command also allows you to optionally:
|
||||
|
||||
- Upload the fused model to the Hugging Face Hub.
|
||||
- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
|
||||
Mixtral, and Llama style models in fp16 precision.
|
||||
|
||||
To see supported options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse --help
|
||||
```
|
||||
|
||||
To generate the fused model run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse --model <path_to_model>
|
||||
```
|
||||
|
||||
This will by default load the adapters from `adapters/`, and save the fused
|
||||
model in the path `lora_fused_model/`. All of these are configurable.
|
||||
|
||||
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
|
||||
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
|
||||
useful for the sake of attribution and model versioning.
|
||||
|
||||
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--upload-repo mlx-community/my-lora-mistral-7b \
|
||||
--hf-path mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
To export a fused model to GGUF, run:
|
||||
|
||||
```shell
|
||||
mlx_lm.fuse \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--export-gguf
|
||||
```
|
||||
|
||||
This will save the GGUF model in `lora_fused_model/ggml-model-f16.gguf`. You
|
||||
can specify the file name with `--gguf-path`.
|
||||
|
||||
## Data
|
||||
|
||||
The LoRA command expects you to provide a dataset with `--data`. The MLX
|
||||
Examples GitHub repo has an [example of the WikiSQL
|
||||
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
|
||||
correct format.
|
||||
|
||||
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
|
||||
Face.
|
||||
|
||||
### Local Datasets
|
||||
|
||||
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||
loader expects a `test.jsonl` in the data directory.
|
||||
|
||||
Currently, `*.jsonl` files support three data formats: `chat`,
|
||||
`completions`, and `text`. Here are three examples of these formats:
|
||||
|
||||
`chat`:
|
||||
|
||||
```jsonl
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "How can I assistant you today."
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`completions`:
|
||||
|
||||
```jsonl
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"completion": "Paris."
|
||||
}
|
||||
```
|
||||
|
||||
`text`:
|
||||
|
||||
```jsonl
|
||||
{
|
||||
"text": "This is an example for the model."
|
||||
}
|
||||
```
|
||||
|
||||
Note, the format is automatically determined by the dataset. Note also, keys in
|
||||
each line not expected by the loader will be ignored.
|
||||
|
||||
### Hugging Face Datasets
|
||||
|
||||
To use Hugging Face datasets, first install the `datasets` package:
|
||||
|
||||
```
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
Specify the Hugging Face dataset arguments in a YAML config. For example:
|
||||
|
||||
```
|
||||
hf_dataset:
|
||||
name: "billsum"
|
||||
prompt_feature: "text"
|
||||
completion_feature: "summary"
|
||||
```
|
||||
|
||||
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||
dataset.
|
||||
|
||||
- To specify the train, valid, or test splits, set the corresponding
|
||||
`{train,valid,test}_split` argument.
|
||||
|
||||
- Arguments specified in `config` will be passed as keyword arguments to
|
||||
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||
|
||||
In general, for the `chat` and `completions` formats, Hugging Face [chat
|
||||
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
||||
the model's chat template by default. If the model does not have a chat
|
||||
template, then Hugging Face will use a default. For example, the final text in
|
||||
the `chat` example above with Hugging Face's default template becomes:
|
||||
|
||||
```text
|
||||
<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Hello.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
How can I assistant you today.<|im_end|>
|
||||
```
|
||||
|
||||
If you are unsure of the format to use, the `chat` or `completions` are good to
|
||||
start with. For custom requirements on the format of the dataset, use the
|
||||
`text` format to assemble the content yourself.
|
||||
|
||||
## Memory Issues
|
||||
|
||||
Fine-tuning a large model with LoRA requires a machine with a decent amount
|
||||
of memory. Here are some tips to reduce memory use should you need to do so:
|
||||
|
||||
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
|
||||
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
|
||||
more details.
|
||||
|
||||
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
|
||||
setting this to `2` or `1` will reduce memory consumption. This may slow
|
||||
things down a little, but will also reduce the memory use.
|
||||
|
||||
3. Reduce the number of layers to fine-tune with `--lora-layers`. The default
|
||||
is `16`, so you can try `8` or `4`. This reduces the amount of memory
|
||||
needed for back propagation. It may also reduce the quality of the
|
||||
fine-tuned model if you are fine-tuning with a lot of data.
|
||||
|
||||
4. Longer examples require more memory. If it makes sense for your data, one thing
|
||||
you can do is break your examples into smaller
|
||||
sequences when making the `{train, valid, test}.jsonl` files.
|
||||
|
||||
5. Gradient checkpointing lets you trade-off memory use (less) for computation
|
||||
(more) by recomputing instead of storing intermediate values needed by the
|
||||
backward pass. You can use gradient checkpointing by passing the
|
||||
`--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
|
||||
larger batch sizes or sequence lengths with smaller or quantized models.
|
||||
|
||||
For example, for a machine with 32 GB the following should run reasonably fast:
|
||||
|
||||
```
|
||||
mlx_lm.lora \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--train \
|
||||
--batch-size 1 \
|
||||
--lora-layers 4 \
|
||||
--data wikisql
|
||||
```
|
||||
|
||||
The above command on an M1 Max with 32 GB runs at about 250
|
||||
tokens-per-second, using the MLX Example
|
||||
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
|
||||
data set.
|
||||
|
||||
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
|
||||
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
|
||||
@@ -1,22 +0,0 @@
|
||||
# Managing Models
|
||||
|
||||
You can use `mlx-lm` to manage models downloaded locally in your machine. They
|
||||
are stored in the Hugging Face cache.
|
||||
|
||||
Scan models:
|
||||
|
||||
```shell
|
||||
mlx_lm.manage --scan
|
||||
```
|
||||
|
||||
Specify a `--pattern` to get info on a single or specific set of models:
|
||||
|
||||
```shell
|
||||
mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
|
||||
```
|
||||
|
||||
To delete a model (or multiple models):
|
||||
|
||||
```shell
|
||||
mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
|
||||
```
|
||||
@@ -1,50 +0,0 @@
|
||||
# Model Merging
|
||||
|
||||
You can use `mlx-lm` to merge models and upload them to the Hugging
|
||||
Face hub or save them locally for LoRA fine tuning.
|
||||
|
||||
The main command is `mlx_lm.merge`:
|
||||
|
||||
```shell
|
||||
mlx_lm.merge --config config.yaml
|
||||
```
|
||||
|
||||
The merged model will be saved by default in `mlx_merged_model`. To see a
|
||||
full list of options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.merge --help
|
||||
```
|
||||
|
||||
Here is an example `config.yaml`:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- OpenPipe/mistral-ft-optimized-1218
|
||||
- mlabonne/NeuralHermes-2.5-Mistral-7B
|
||||
method: slerp
|
||||
parameters:
|
||||
t:
|
||||
- filter: self_attn
|
||||
value: [0, 0.5, 0.3, 0.7, 1]
|
||||
- filter: mlp
|
||||
value: [1, 0.5, 0.7, 0.3, 0]
|
||||
- value: 0.5
|
||||
```
|
||||
|
||||
The `models` field is a list of Hugging Face repo ids. The first model in the
|
||||
list is treated as the base model into which the remaining models are merged.
|
||||
|
||||
The `method` field is the merging method. Right now `slerp` is the only
|
||||
supported method.
|
||||
|
||||
The `parameters` are the corresponding parameters for the given `method`.
|
||||
Each parameter is a list with `filter` determining which layer the parameter
|
||||
applies to and `value` determining the actual value used. The last item in
|
||||
the list without a `filter` field is the default.
|
||||
|
||||
If `value` is a list, it specifies the start and end values for the
|
||||
corresponding segment of blocks. In the example above, the models have 32
|
||||
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
|
||||
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
|
||||
will use `np.linspace(0.5, 0.3, 8)`, and so on.
|
||||
@@ -1,10 +0,0 @@
|
||||
## Generate Text with MLX and :hugs: Hugging Face
|
||||
|
||||
This an example of large language model text generation that can pull models from
|
||||
the Hugging Face Hub.
|
||||
|
||||
For more information on this example, see the [README](../README.md) in the
|
||||
parent directory.
|
||||
|
||||
This package also supports fine tuning with LoRA or QLoRA. For more information
|
||||
see the [LoRA documentation](LORA.md).
|
||||
@@ -1,80 +0,0 @@
|
||||
# HTTP Model Server
|
||||
|
||||
You use `mlx-lm` to make an HTTP API for generating text with any supported
|
||||
model. The HTTP API is intended to be similar to the [OpenAI chat
|
||||
API](https://platform.openai.com/docs/api-reference).
|
||||
|
||||
> [!NOTE]
|
||||
> The MLX LM server is not recommended for production as it only implements
|
||||
> basic security checks.
|
||||
|
||||
Start the server with:
|
||||
|
||||
```shell
|
||||
mlx_lm.server --model <path_to_model_or_hf_repo>
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```shell
|
||||
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
|
||||
```
|
||||
|
||||
This will start a text generation server on port `8080` of the `localhost`
|
||||
using Mistral 7B instruct. The model will be downloaded from the provided
|
||||
Hugging Face repo if it is not already in the local cache.
|
||||
|
||||
To see a full list of options run:
|
||||
|
||||
```shell
|
||||
mlx_lm.server --help
|
||||
```
|
||||
|
||||
You can make a request to the model by running:
|
||||
|
||||
```shell
|
||||
curl localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
### Request Fields
|
||||
|
||||
- `messages`: An array of message objects representing the conversation
|
||||
history. Each message object should have a role (e.g. user, assistant) and
|
||||
content (the message text).
|
||||
|
||||
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
||||
the generated prompt. If not provided, the default mappings are used.
|
||||
|
||||
- `stop`: (Optional) An array of strings or a single string. Thesse are
|
||||
sequences of tokens on which the generation should stop.
|
||||
|
||||
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
||||
to generate. Defaults to `100`.
|
||||
|
||||
- `stream`: (Optional) A boolean indicating if the response should be
|
||||
streamed. If true, responses are sent as they are generated. Defaults to
|
||||
false.
|
||||
|
||||
- `temperature`: (Optional) A float specifying the sampling temperature.
|
||||
Defaults to `1.0`.
|
||||
|
||||
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
|
||||
Defaults to `1.0`.
|
||||
|
||||
- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
|
||||
Defaults to `1.0`.
|
||||
|
||||
- `repetition_context_size`: (Optional) The size of the context window for
|
||||
applying repetition penalty. Defaults to `20`.
|
||||
|
||||
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
|
||||
values. Defaults to `None`.
|
||||
|
||||
- `logprobs`: (Optional) An integer specifying the number of top tokens and
|
||||
corresponding log probabilities to return for each output in the generated
|
||||
sequence. If set, this can be any value between 1 and 10, inclusive.
|
||||
@@ -1,37 +0,0 @@
|
||||
### Packaging for PyPI
|
||||
|
||||
Install `build` and `twine`:
|
||||
|
||||
```
|
||||
pip install --user --upgrade build
|
||||
pip install --user --upgrade twine
|
||||
```
|
||||
|
||||
Generate the source distribution and wheel:
|
||||
|
||||
```
|
||||
python -m build
|
||||
```
|
||||
|
||||
> [!warning]
|
||||
> Use a test server first
|
||||
|
||||
#### Test Upload
|
||||
|
||||
Upload to test server:
|
||||
|
||||
```
|
||||
python -m twine upload --repository testpypi dist/*
|
||||
```
|
||||
|
||||
Install from test server and check that it works:
|
||||
|
||||
```
|
||||
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
|
||||
```
|
||||
|
||||
#### Upload
|
||||
|
||||
```
|
||||
python -m twine upload dist/*
|
||||
```
|
||||
@@ -1,4 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from .utils import convert, generate, load, stream_generate
|
||||
from .version import __version__
|
||||
@@ -1,62 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
from .utils import convert
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face model to MLX format"
|
||||
)
|
||||
|
||||
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
|
||||
parser.add_argument(
|
||||
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-group-size", help="Group size for quantization.", type=int, default=64
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="Type to save the parameters, ignored if -q is given.",
|
||||
type=str,
|
||||
choices=["float16", "bfloat16", "float32"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dequantize",
|
||||
help="Dequantize a quantized model.",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
convert(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,40 +0,0 @@
|
||||
from mlx_lm import generate, load
|
||||
|
||||
# Specify the checkpoint
|
||||
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
# Load the corresponding model and tokenizer
|
||||
model, tokenizer = load(path_or_hf_repo=checkpoint)
|
||||
|
||||
# Specify the prompt and conversation history
|
||||
prompt = "Why is the sky blue?"
|
||||
conversation = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Transform the prompt into the chat template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Specify the maximum number of tokens
|
||||
max_tokens = 1_000
|
||||
|
||||
# Specify if tokens and timing information will be printed
|
||||
verbose = True
|
||||
|
||||
# Some optional arguments for causal language model generation
|
||||
generation_args = {
|
||||
"temp": 0.7,
|
||||
"repetition_penalty": 1.2,
|
||||
"repetition_context_size": 20,
|
||||
"top_p": 0.95,
|
||||
}
|
||||
|
||||
# Generate a response with the specified settings
|
||||
response = generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
verbose=verbose,
|
||||
**generation_args,
|
||||
)
|
||||
@@ -1,79 +0,0 @@
|
||||
# The path to the local model directory or Hugging Face repo.
|
||||
model: "mlx_model"
|
||||
# Whether or not to train (boolean)
|
||||
train: true
|
||||
|
||||
# Directory with {train, valid, test}.jsonl files
|
||||
data: "/path/to/training/data"
|
||||
|
||||
# The PRNG seed
|
||||
seed: 0
|
||||
|
||||
# Number of layers to fine-tune
|
||||
lora_layers: 16
|
||||
|
||||
# Minibatch size.
|
||||
batch_size: 4
|
||||
|
||||
# Iterations to train for.
|
||||
iters: 1000
|
||||
|
||||
# Number of validation batches, -1 uses the entire validation set.
|
||||
val_batches: 25
|
||||
|
||||
# Adam learning rate.
|
||||
learning_rate: 1e-5
|
||||
|
||||
# Number of training steps between loss reporting.
|
||||
steps_per_report: 10
|
||||
|
||||
# Number of training steps between validations.
|
||||
steps_per_eval: 200
|
||||
|
||||
# Load path to resume training with the given adapter weights.
|
||||
resume_adapter_file: null
|
||||
|
||||
# Save/load path for the trained adapter weights.
|
||||
adapter_path: "adapters"
|
||||
|
||||
# Save the model every N iterations.
|
||||
save_every: 100
|
||||
|
||||
# Evaluate on the test set after training
|
||||
test: false
|
||||
|
||||
# Number of test set batches, -1 uses the entire test set.
|
||||
test_batches: 100
|
||||
|
||||
# Maximum sequence length.
|
||||
max_seq_length: 2048
|
||||
|
||||
# Use gradient checkpointing to reduce memory use.
|
||||
grad_checkpoint: false
|
||||
|
||||
# Use DoRA instead of LoRA.
|
||||
use_dora: false
|
||||
|
||||
# LoRA parameters can only be specified in a config file
|
||||
lora_parameters:
|
||||
# The layer keys to apply LoRA to.
|
||||
# These will be applied for the last lora_layers
|
||||
keys: ["self_attn.q_proj", "self_attn.v_proj"]
|
||||
rank: 8
|
||||
scale: 20.0
|
||||
dropout: 0.0
|
||||
|
||||
# Schedule can only be specified in a config file, uncomment to use.
|
||||
#lr_schedule:
|
||||
# name: cosine_decay
|
||||
# warmup: 100 # 0 for no warmup
|
||||
# warmup_init: 1e-7 # 0 if not specified
|
||||
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
|
||||
|
||||
#hf_dataset:
|
||||
# name: "billsum"
|
||||
# train_split: "train[:1000]"
|
||||
# valid_split: "train[-100:]"
|
||||
# prompt_feature: "text"
|
||||
# completion_feature: "summary"
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
models:
|
||||
- OpenPipe/mistral-ft-optimized-1218
|
||||
- mlabonne/NeuralHermes-2.5-Mistral-7B
|
||||
method: slerp
|
||||
parameters:
|
||||
t:
|
||||
- filter: self_attn
|
||||
value: [0, 0.5, 0.3, 0.7, 1]
|
||||
- filter: mlp
|
||||
value: [1, 0.5, 0.7, 0.3, 0]
|
||||
- value: 0.5
|
||||
@@ -1,131 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
from .gguf import convert_to_gguf
|
||||
from .tuner.dora import DoRALinear
|
||||
from .tuner.lora import LoRALinear, LoRASwitchLinear
|
||||
from .tuner.utils import apply_lora_layers, dequantize
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Fuse fine-tuned adapters into the base model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
default="lora_fused_model",
|
||||
help="The path to save the fused model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
default="adapters",
|
||||
help="Path to the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--de-quantize",
|
||||
help="Generate a de-quantized model.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-gguf",
|
||||
help="Export model weights in GGUF format.",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gguf-path",
|
||||
help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
|
||||
default="ggml-model-f16.gguf",
|
||||
type=str,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
print("Loading pretrained model")
|
||||
args = parse_arguments()
|
||||
|
||||
model_path = get_model_path(args.model)
|
||||
model, config, tokenizer = fetch_from_hub(model_path)
|
||||
|
||||
model.freeze()
|
||||
model = apply_lora_layers(model, args.adapter_path)
|
||||
|
||||
fused_linears = [
|
||||
(n, m.to_linear())
|
||||
for n, m in model.named_modules()
|
||||
if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
|
||||
]
|
||||
|
||||
model.update_modules(tree_unflatten(fused_linears))
|
||||
|
||||
if args.de_quantize:
|
||||
print("De-quantizing model")
|
||||
model = dequantize(model)
|
||||
|
||||
weights = dict(tree_flatten(model.parameters()))
|
||||
|
||||
save_path = Path(args.save_path)
|
||||
|
||||
save_weights(save_path, weights)
|
||||
|
||||
py_files = glob.glob(str(model_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, save_path)
|
||||
|
||||
tokenizer.save_pretrained(save_path)
|
||||
|
||||
if args.de_quantize:
|
||||
config.pop("quantization", None)
|
||||
|
||||
save_config(config, config_path=save_path / "config.json")
|
||||
|
||||
if args.export_gguf:
|
||||
model_type = config["model_type"]
|
||||
if model_type not in ["llama", "mixtral", "mistral"]:
|
||||
raise ValueError(
|
||||
f"Model type {model_type} not supported for GGUF conversion."
|
||||
)
|
||||
convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))
|
||||
|
||||
if args.upload_repo is not None:
|
||||
hf_path = args.hf_path or (
|
||||
args.model if not Path(args.model).exists() else None
|
||||
)
|
||||
if hf_path is None:
|
||||
raise ValueError(
|
||||
"Must provide original Hugging Face repo to upload local model."
|
||||
)
|
||||
upload_to_hub(args.save_path, args.upload_repo, hf_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,161 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_MODEL_PATH = "mlx_model"
|
||||
DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.6
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(description="LLM inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="mlx_model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Enable trusting remote code for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos-token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="End of sequence token for tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
"-m",
|
||||
type=int,
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
parser.add_argument(
|
||||
"--ignore-chat-template",
|
||||
action="store_true",
|
||||
help="Use the raw prompt without the tokenizer's chat template.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-default-chat-template",
|
||||
action="store_true",
|
||||
help="Use the default chat template",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--colorize",
|
||||
action="store_true",
|
||||
help="Colorize output based on T[0] probability",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache-limit-gb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the MLX cache limit in GB",
|
||||
required=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def colorprint(color, s):
|
||||
color_codes = {
|
||||
"black": 30,
|
||||
"red": 31,
|
||||
"green": 32,
|
||||
"yellow": 33,
|
||||
"blue": 34,
|
||||
"magenta": 35,
|
||||
"cyan": 36,
|
||||
"white": 39,
|
||||
}
|
||||
ccode = color_codes.get(color, 30)
|
||||
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
|
||||
|
||||
|
||||
def colorprint_by_t0(s, t0):
|
||||
if t0 > 0.95:
|
||||
color = "white"
|
||||
elif t0 > 0.70:
|
||||
color = "green"
|
||||
elif t0 > 0.30:
|
||||
color = "yellow"
|
||||
else:
|
||||
color = "red"
|
||||
colorprint(color, s)
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
model, tokenizer = load(
|
||||
args.model,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config=tokenizer_config,
|
||||
)
|
||||
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_tokens,
|
||||
verbose=True,
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,313 +0,0 @@
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class TokenType(IntEnum):
|
||||
NORMAL = 1
|
||||
UNKNOWN = 2
|
||||
CONTROL = 3
|
||||
USER_DEFINED = 4
|
||||
UNUSED = 5
|
||||
BYTE = 6
|
||||
|
||||
|
||||
class GGMLFileType(IntEnum):
|
||||
GGML_TYPE_F16 = 1
|
||||
|
||||
|
||||
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
|
||||
class HfVocab:
|
||||
def __init__(
|
||||
self,
|
||||
fname_tokenizer: Path,
|
||||
fname_added_tokens: Optional[Union[Path, None]] = None,
|
||||
) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
fname_tokenizer,
|
||||
cache_dir=fname_tokenizer,
|
||||
local_files_only=True,
|
||||
)
|
||||
self.added_tokens_list = []
|
||||
self.added_tokens_dict = dict()
|
||||
self.added_tokens_ids = set()
|
||||
for tok, tokidx in sorted(
|
||||
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
|
||||
):
|
||||
if tokidx >= self.tokenizer.vocab_size:
|
||||
self.added_tokens_list.append(tok)
|
||||
self.added_tokens_dict[tok] = tokidx
|
||||
self.added_tokens_ids.add(tokidx)
|
||||
self.specials = {
|
||||
tok: self.tokenizer.get_vocab()[tok]
|
||||
for tok in self.tokenizer.all_special_tokens
|
||||
}
|
||||
self.special_ids = set(self.tokenizer.all_special_ids)
|
||||
self.vocab_size_base = self.tokenizer.vocab_size
|
||||
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
|
||||
self.fname_tokenizer = fname_tokenizer
|
||||
self.fname_added_tokens = fname_added_tokens
|
||||
|
||||
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
reverse_vocab = {
|
||||
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
|
||||
}
|
||||
for token_id in range(self.vocab_size_base):
|
||||
if token_id in self.added_tokens_ids:
|
||||
continue
|
||||
token_text = reverse_vocab[token_id].encode("utf-8")
|
||||
yield token_text, self.get_token_score(token_id), self.get_token_type(
|
||||
token_id, token_text, self.special_ids
|
||||
)
|
||||
|
||||
def get_token_type(
|
||||
self, token_id: int, token_text: bytes, special_ids: Set[int]
|
||||
) -> TokenType:
|
||||
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
|
||||
return TokenType.BYTE
|
||||
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
|
||||
|
||||
def get_token_score(self, token_id: int) -> float:
|
||||
return -1000.0
|
||||
|
||||
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
for text in self.added_tokens_list:
|
||||
if text in self.specials:
|
||||
toktype = self.get_token_type(
|
||||
self.specials[text], b"", self.special_ids
|
||||
)
|
||||
score = self.get_token_score(self.specials[text])
|
||||
else:
|
||||
toktype = TokenType.USER_DEFINED
|
||||
score = -1000.0
|
||||
yield text.encode("utf-8"), score, toktype
|
||||
|
||||
def has_newline_token(self):
|
||||
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
|
||||
|
||||
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
|
||||
yield from self.hf_tokens()
|
||||
yield from self.added_tokens()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
@staticmethod
|
||||
def load(path: Path) -> "HfVocab":
|
||||
added_tokens_path = path.parent / "added_tokens.json"
|
||||
return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
|
||||
|
||||
|
||||
def translate_weight_names(name):
|
||||
name = name.replace("model.layers.", "blk.")
|
||||
# for mixtral gate
|
||||
name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
|
||||
# for mixtral experts ffns
|
||||
pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
|
||||
replacement = r"ffn_gate.\1.weight"
|
||||
name = re.sub(pattern, replacement, name)
|
||||
pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
|
||||
replacement = r"ffn_down.\1.weight"
|
||||
name = re.sub(pattern, replacement, name)
|
||||
pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
|
||||
replacement = r"ffn_up.\1.weight"
|
||||
name = re.sub(pattern, replacement, name)
|
||||
|
||||
name = name.replace("mlp.gate_proj", "ffn_gate")
|
||||
name = name.replace("mlp.down_proj", "ffn_down")
|
||||
name = name.replace("mlp.up_proj", "ffn_up")
|
||||
name = name.replace("self_attn.q_proj", "attn_q")
|
||||
name = name.replace("self_attn.k_proj", "attn_k")
|
||||
name = name.replace("self_attn.v_proj", "attn_v")
|
||||
name = name.replace("self_attn.o_proj", "attn_output")
|
||||
name = name.replace("input_layernorm", "attn_norm")
|
||||
name = name.replace("post_attention_layernorm", "ffn_norm")
|
||||
name = name.replace("model.embed_tokens", "token_embd")
|
||||
name = name.replace("model.norm", "output_norm")
|
||||
name = name.replace("lm_head", "output")
|
||||
return name
|
||||
|
||||
|
||||
def permute_weights(weights, n_head, n_head_kv=None):
|
||||
if n_head_kv is not None and n_head != n_head_kv:
|
||||
n_head = n_head_kv
|
||||
reshaped = weights.reshape(
|
||||
n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
|
||||
)
|
||||
swapped = reshaped.swapaxes(1, 2)
|
||||
final_shape = weights.shape
|
||||
return swapped.reshape(final_shape)
|
||||
|
||||
|
||||
def prepare_metadata(config, vocab):
|
||||
metadata = {
|
||||
"general.name": "llama",
|
||||
"llama.context_length": (
|
||||
mx.array(config["max_position_embeddings"], dtype=mx.uint32)
|
||||
if config.get("max_position_embeddings") is not None
|
||||
else None
|
||||
),
|
||||
"llama.embedding_length": (
|
||||
mx.array(config["hidden_size"], dtype=mx.uint32)
|
||||
if config.get("hidden_size") is not None
|
||||
else None
|
||||
),
|
||||
"llama.block_count": (
|
||||
mx.array(config["num_hidden_layers"], dtype=mx.uint32)
|
||||
if config.get("num_hidden_layers") is not None
|
||||
else None
|
||||
),
|
||||
"llama.feed_forward_length": (
|
||||
mx.array(config["intermediate_size"], dtype=mx.uint32)
|
||||
if config.get("intermediate_size") is not None
|
||||
else None
|
||||
),
|
||||
"llama.rope.dimension_count": (
|
||||
mx.array(
|
||||
config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
|
||||
)
|
||||
if config.get("hidden_size") is not None
|
||||
and config.get("num_attention_heads") is not None
|
||||
else None
|
||||
),
|
||||
"llama.attention.head_count": (
|
||||
mx.array(config["num_attention_heads"], dtype=mx.uint32)
|
||||
if config.get("num_attention_heads") is not None
|
||||
else None
|
||||
),
|
||||
"llama.attention.head_count_kv": (
|
||||
mx.array(
|
||||
config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||
dtype=mx.uint32,
|
||||
)
|
||||
if config.get("num_attention_heads") is not None
|
||||
else None
|
||||
),
|
||||
"llama.expert_count": (
|
||||
mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
|
||||
if config.get("num_local_experts") is not None
|
||||
else None
|
||||
),
|
||||
"llama.expert_used_count": (
|
||||
mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
|
||||
if config.get("num_experts_per_tok") is not None
|
||||
else None
|
||||
),
|
||||
"llama.attention.layer_norm_rms_epsilon": (
|
||||
mx.array(config.get("rms_norm_eps", 1e-05))
|
||||
if config.get("rms_norm_eps") is not None
|
||||
else None
|
||||
),
|
||||
"llama.rope.freq_base": (
|
||||
mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
|
||||
if config.get("rope_theta") is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
rope_scaling = config.get("rope_scaling")
|
||||
if rope_scaling is not None and (typ := rope_scaling.get("type")):
|
||||
rope_factor = rope_scaling.get("factor")
|
||||
f_rope_scale = rope_factor
|
||||
if typ == "linear":
|
||||
rope_scaling_type = "linear"
|
||||
metadata["llama.rope.scaling.type"] = rope_scaling_type
|
||||
metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)
|
||||
|
||||
metadata["general.file_type"] = mx.array(
|
||||
GGMLFileType.GGML_TYPE_F16.value,
|
||||
dtype=mx.uint32,
|
||||
)
|
||||
metadata["general.quantization_version"] = mx.array(
|
||||
GGMLFileType.GGML_TYPE_F16.value,
|
||||
dtype=mx.uint32,
|
||||
)
|
||||
metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
|
||||
metadata["general.architecture"] = "llama"
|
||||
metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)
|
||||
|
||||
# add metadata for vocab
|
||||
metadata["tokenizer.ggml.model"] = "llama"
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
for text, score, toktype in vocab.all_tokens():
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype.value)
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
metadata["tokenizer.ggml.tokens"] = tokens
|
||||
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
|
||||
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
|
||||
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
|
||||
vocab.tokenizer.bos_token_id, dtype=mx.uint32
|
||||
)
|
||||
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
|
||||
vocab.tokenizer.eos_token_id, dtype=mx.uint32
|
||||
)
|
||||
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
|
||||
vocab.tokenizer.unk_token_id, dtype=mx.uint32
|
||||
)
|
||||
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
return metadata
|
||||
|
||||
|
||||
def convert_to_gguf(
|
||||
model_path: Union[str, Path],
|
||||
weights: dict,
|
||||
config: dict,
|
||||
output_file_path: str,
|
||||
):
|
||||
if isinstance(model_path, str):
|
||||
model_path = Path(model_path)
|
||||
|
||||
quantization = config.get("quantization", None)
|
||||
if quantization:
|
||||
raise NotImplementedError(
|
||||
"Conversion of quantized models is not yet supported."
|
||||
)
|
||||
print("Converting to GGUF format")
|
||||
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention
|
||||
weights = {
|
||||
k: (
|
||||
permute_weights(
|
||||
v, config["num_attention_heads"], config["num_attention_heads"]
|
||||
)
|
||||
if "self_attn.q_proj.weight" in k
|
||||
else (
|
||||
permute_weights(
|
||||
v, config["num_attention_heads"], config["num_key_value_heads"]
|
||||
)
|
||||
if "self_attn.k_proj.weight" in k
|
||||
else v
|
||||
)
|
||||
)
|
||||
for k, v in weights.items()
|
||||
}
|
||||
|
||||
# rename weights for gguf format
|
||||
weights = {translate_weight_names(k): v for k, v in weights.items()}
|
||||
|
||||
if not (model_path / "tokenizer.json").exists():
|
||||
raise ValueError("Tokenizer json not found")
|
||||
|
||||
vocab = HfVocab.load(model_path)
|
||||
metadata = prepare_metadata(config, vocab)
|
||||
|
||||
weights = {
|
||||
k: (
|
||||
v.astype(mx.float32).astype(mx.float16)
|
||||
if v.dtype == mx.bfloat16
|
||||
else v.astype(mx.float32) if "norm" in k else v
|
||||
)
|
||||
for k, v in weights.items()
|
||||
}
|
||||
|
||||
output_file_path = output_file_path
|
||||
mx.save_gguf(output_file_path, weights, metadata)
|
||||
print(f"Converted GGUF model saved as: {output_file_path}")
|
||||
@@ -1,278 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import re
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from .tokenizer_utils import TokenizerWrapper
|
||||
from .tuner.datasets import load_dataset
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.utils import (
|
||||
apply_lora_layers,
|
||||
build_schedule,
|
||||
linear_to_lora_layers,
|
||||
print_trainable_parameters,
|
||||
)
|
||||
from .utils import load, save_config
|
||||
|
||||
yaml_loader = yaml.SafeLoader
|
||||
yaml_loader.add_implicit_resolver(
|
||||
"tag:yaml.org,2002:float",
|
||||
re.compile(
|
||||
"""^(?:
|
||||
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||
|[-+]?\\.(?:inf|Inf|INF)
|
||||
|\\.(?:nan|NaN|NAN))$""",
|
||||
re.X,
|
||||
),
|
||||
list("-+0123456789."),
|
||||
)
|
||||
|
||||
CONFIG_DEFAULTS = {
|
||||
"model": "mlx_model",
|
||||
"train": False,
|
||||
"data": "data/",
|
||||
"seed": 0,
|
||||
"lora_layers": 16,
|
||||
"batch_size": 4,
|
||||
"iters": 1000,
|
||||
"val_batches": 25,
|
||||
"learning_rate": 1e-5,
|
||||
"steps_per_report": 10,
|
||||
"steps_per_eval": 200,
|
||||
"resume_adapter_file": None,
|
||||
"adapter_path": "adapters",
|
||||
"save_every": 100,
|
||||
"test": False,
|
||||
"test_batches": 500,
|
||||
"max_seq_length": 2048,
|
||||
"lr_schedule": None,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
"use_dora": False,
|
||||
}
|
||||
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
)
|
||||
|
||||
# Training args
|
||||
parser.add_argument(
|
||||
"--train",
|
||||
action="store_true",
|
||||
help="Do training",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=str,
|
||||
help="Directory with {train, valid, test}.jsonl files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-layers",
|
||||
type=int,
|
||||
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
|
||||
parser.add_argument("--iters", type=int, help="Iterations to train for.")
|
||||
parser.add_argument(
|
||||
"--val-batches",
|
||||
type=int,
|
||||
help="Number of validation batches, -1 uses the entire validation set.",
|
||||
)
|
||||
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
|
||||
parser.add_argument(
|
||||
"--steps-per-report",
|
||||
type=int,
|
||||
help="Number of training steps between loss reporting.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps-per-eval",
|
||||
type=int,
|
||||
help="Number of training steps between validations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume-adapter-file",
|
||||
type=str,
|
||||
help="Load path to resume training with the given adapters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Save/load path for the adapters.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-every",
|
||||
type=int,
|
||||
help="Save the model every N iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Evaluate on the test set after training",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-batches",
|
||||
type=int,
|
||||
help="Number of test set batches, -1 uses the entire test set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-seq-length",
|
||||
type=int,
|
||||
help="Maximum sequence length.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
default=None,
|
||||
help="A YAML configuration file with the training options",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad-checkpoint",
|
||||
action="store_true",
|
||||
help="Use gradient checkpointing to reduce memory use.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
||||
parser.add_argument(
|
||||
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def train_model(
|
||||
args,
|
||||
model: nn.Module,
|
||||
tokenizer: TokenizerWrapper,
|
||||
train_set,
|
||||
valid_set,
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
# Freeze all layers
|
||||
model.freeze()
|
||||
|
||||
# Convert linear layers to lora layers and unfreeze in the process
|
||||
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
|
||||
|
||||
# Resume training the given adapters.
|
||||
if args.resume_adapter_file is not None:
|
||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file, strict=False)
|
||||
|
||||
print_trainable_parameters(model)
|
||||
|
||||
adapter_path = Path(args.adapter_path)
|
||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||
adapter_file = adapter_path / "adapters.safetensors"
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
|
||||
# init training args
|
||||
training_args = TrainingArgs(
|
||||
batch_size=args.batch_size,
|
||||
iters=args.iters,
|
||||
val_batches=args.val_batches,
|
||||
steps_per_report=args.steps_per_report,
|
||||
steps_per_eval=args.steps_per_eval,
|
||||
steps_per_save=args.save_every,
|
||||
adapter_file=adapter_file,
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint,
|
||||
)
|
||||
|
||||
model.train()
|
||||
opt = optim.Adam(
|
||||
learning_rate=(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
)
|
||||
# Train model
|
||||
train(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
args=training_args,
|
||||
optimizer=opt,
|
||||
train_dataset=train_set,
|
||||
val_dataset=valid_set,
|
||||
training_callback=training_callback,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
|
||||
model.eval()
|
||||
|
||||
test_loss = evaluate(
|
||||
model=model,
|
||||
dataset=test_set,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.test_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
)
|
||||
|
||||
test_ppl = math.exp(test_loss)
|
||||
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
||||
def run(args, training_callback: TrainingCallback = None):
|
||||
np.random.seed(args.seed)
|
||||
|
||||
print("Loading pretrained model")
|
||||
model, tokenizer = load(args.model)
|
||||
|
||||
print("Loading datasets")
|
||||
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
||||
|
||||
if args.test and not args.train:
|
||||
# Allow testing without LoRA layers by providing empty path
|
||||
if args.adapter_path != "":
|
||||
apply_lora_layers(model, args.adapter_path)
|
||||
|
||||
elif args.train:
|
||||
print("Training")
|
||||
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
|
||||
else:
|
||||
raise ValueError("Must provide at least one of --train or --test")
|
||||
|
||||
if args.test:
|
||||
print("Testing")
|
||||
evaluate_model(args, model, tokenizer, test_set)
|
||||
|
||||
|
||||
def main():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
config = args.config
|
||||
args = vars(args)
|
||||
if config:
|
||||
print("Loading configuration file", config)
|
||||
with open(config, "r") as file:
|
||||
config = yaml.load(file, yaml_loader)
|
||||
# Prefer parameters from command-line arguments
|
||||
for k, v in config.items():
|
||||
if args.get(k, None) is None:
|
||||
args[k] = v
|
||||
|
||||
# Update defaults for unspecified parameters
|
||||
for k, v in CONFIG_DEFAULTS.items():
|
||||
if args.get(k, None) is None:
|
||||
args[k] = v
|
||||
run(types.SimpleNamespace(**args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,121 +0,0 @@
|
||||
import argparse
|
||||
from typing import List, Union
|
||||
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from transformers.commands.user import tabulate
|
||||
|
||||
|
||||
def ask_for_confirmation(message: str) -> bool:
|
||||
y = ("y", "yes", "1")
|
||||
n = ("n", "no", "0")
|
||||
all_values = y + n + ("",)
|
||||
full_message = f"{message} (Y/n) "
|
||||
while True:
|
||||
answer = input(full_message).lower()
|
||||
if answer == "":
|
||||
return False
|
||||
if answer in y:
|
||||
return True
|
||||
if answer in n:
|
||||
return False
|
||||
print(f"Invalid input. Must be one of {all_values}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MLX Model Cache.")
|
||||
parser.add_argument(
|
||||
"--scan",
|
||||
action="store_true",
|
||||
help="Scan Hugging Face cache for mlx models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
action="store_true",
|
||||
help="Delete models matching the given pattern.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pattern",
|
||||
type=str,
|
||||
help="Model repos contain the pattern.",
|
||||
default="mlx",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.scan:
|
||||
print(
|
||||
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
|
||||
)
|
||||
hf_cache_info = scan_cache_dir()
|
||||
print(
|
||||
tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
repo.repo_type,
|
||||
"{:>12}".format(repo.size_on_disk_str),
|
||||
repo.nb_files,
|
||||
repo.last_accessed_str,
|
||||
repo.last_modified_str,
|
||||
str(repo.repo_path),
|
||||
]
|
||||
for repo in sorted(
|
||||
hf_cache_info.repos, key=lambda repo: repo.repo_path
|
||||
)
|
||||
if args.pattern in repo.repo_id
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"REPO TYPE",
|
||||
"SIZE ON DISK",
|
||||
"NB FILES",
|
||||
"LAST_ACCESSED",
|
||||
"LAST_MODIFIED",
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if args.delete:
|
||||
print(f'Deleting models matching pattern "{args.pattern}"')
|
||||
hf_cache_info = scan_cache_dir()
|
||||
|
||||
repos = [
|
||||
repo
|
||||
for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
|
||||
if args.pattern in repo.repo_id
|
||||
]
|
||||
if repos:
|
||||
print(
|
||||
tabulate(
|
||||
rows=[
|
||||
[
|
||||
repo.repo_id,
|
||||
str(repo.repo_path),
|
||||
]
|
||||
for repo in repos
|
||||
],
|
||||
headers=[
|
||||
"REPO ID",
|
||||
"LOCAL PATH",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
confirmed = ask_for_confirmation(f"Confirm deletion ?")
|
||||
if confirmed:
|
||||
for model_info in repos:
|
||||
for revision in sorted(
|
||||
model_info.revisions, key=lambda revision: revision.commit_hash
|
||||
):
|
||||
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
|
||||
strategy.execute()
|
||||
print("Model(s) deleted.")
|
||||
else:
|
||||
print("Deletion is cancelled. Do nothing.")
|
||||
else:
|
||||
print(f"No models found.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,172 +0,0 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import yaml
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
from .utils import (
|
||||
fetch_from_hub,
|
||||
get_model_path,
|
||||
save_config,
|
||||
save_weights,
|
||||
upload_to_hub,
|
||||
)
|
||||
|
||||
|
||||
def configure_parser() -> argparse.ArgumentParser:
|
||||
"""
|
||||
Configures and returns the argument parser for the script.
|
||||
|
||||
Returns:
|
||||
argparse.ArgumentParser: Configured argument parser.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Merge multiple models.")
|
||||
|
||||
parser.add_argument("--config", type=str, help="Path to the YAML config.")
|
||||
parser.add_argument(
|
||||
"--mlx-path",
|
||||
type=str,
|
||||
default="mlx_merged_model",
|
||||
help="Path to save the MLX model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-repo",
|
||||
help="The Hugging Face repo to upload the model to.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def slerp(t, w1, w2, eps=1e-5):
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
|
||||
Args:
|
||||
t (float): Interpolation weight in [0.0, 1.0]
|
||||
w1 (mx.array): First input
|
||||
w2 (mx.array): Second input
|
||||
eps (float): Constant for numerical stability
|
||||
Returns:
|
||||
mx.array: Interpolated result
|
||||
"""
|
||||
t = float(t)
|
||||
if t == 0:
|
||||
return w1
|
||||
elif t == 1:
|
||||
return w2
|
||||
# Normalize
|
||||
v1 = w1 / mx.linalg.norm(w1)
|
||||
v2 = w2 / mx.linalg.norm(w2)
|
||||
# Angle
|
||||
dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
|
||||
theta = mx.arccos(dot)
|
||||
sin_theta = mx.sin(theta + eps)
|
||||
s1 = mx.sin(theta * (1 - t)) / sin_theta
|
||||
s2 = mx.sin(theta * t) / sin_theta
|
||||
return s1 * w1 + s2 * w2
|
||||
|
||||
|
||||
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
|
||||
method = config.get("method", None)
|
||||
if method != "slerp":
|
||||
raise ValueError(f"Merge method {method} not supported")
|
||||
|
||||
num_layers = len(model.layers)
|
||||
|
||||
def unpack_values(vals):
|
||||
if isinstance(vals, (int, float)):
|
||||
return np.full(num_layers, vals)
|
||||
bins = len(vals) - 1
|
||||
sizes = [num_layers // bins] * bins
|
||||
sizes[-1] = num_layers - sum(sizes[:-1])
|
||||
return np.concatenate(
|
||||
[np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
|
||||
)
|
||||
|
||||
param_list = config["parameters"]["t"]
|
||||
params = {}
|
||||
filter_keys = set()
|
||||
for pl in param_list[:-1]:
|
||||
params[pl["filter"]] = unpack_values(pl["value"])
|
||||
filter_keys.add(pl["filter"])
|
||||
default = unpack_values(param_list[-1]["value"])
|
||||
|
||||
for e in range(num_layers):
|
||||
bl = base_model.layers[e]
|
||||
l = model.layers[e]
|
||||
base_weights = bl.parameters()
|
||||
weights = l.parameters()
|
||||
for k, w1 in base_weights.items():
|
||||
w2 = weights[k]
|
||||
t = params.get(k, default)[e]
|
||||
base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
|
||||
base_model.update(base_weights)
|
||||
|
||||
|
||||
def merge(
|
||||
config: str,
|
||||
mlx_path: str = "mlx_model",
|
||||
upload_repo: Optional[str] = None,
|
||||
):
|
||||
with open(config, "r") as fid:
|
||||
merge_conf = yaml.safe_load(fid)
|
||||
print("[INFO] Loading")
|
||||
|
||||
model_paths = merge_conf.get("models", [])
|
||||
if len(model_paths) < 2:
|
||||
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
|
||||
|
||||
# Load all models
|
||||
base_hf_path = model_paths[0]
|
||||
base_path = get_model_path(base_hf_path)
|
||||
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
|
||||
models = []
|
||||
for mp in model_paths[1:]:
|
||||
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
|
||||
base_type = base_config["model_type"]
|
||||
model_type = model_config["model_type"]
|
||||
if base_type != model_type:
|
||||
raise ValueError(
|
||||
f"Can only merge models of the same type,"
|
||||
f" but got {base_type} and {model_type}."
|
||||
)
|
||||
models.append(model)
|
||||
|
||||
# Merge models into base model
|
||||
for m in models:
|
||||
merge_models(base_model, m, merge_conf)
|
||||
|
||||
# Save base model
|
||||
mlx_path = Path(mlx_path)
|
||||
weights = dict(tree_flatten(base_model.parameters()))
|
||||
del models, base_model
|
||||
save_weights(mlx_path, weights, donate_weights=True)
|
||||
py_files = glob.glob(str(base_path / "*.py"))
|
||||
for file in py_files:
|
||||
shutil.copy(file, mlx_path)
|
||||
|
||||
tokenizer.save_pretrained(mlx_path)
|
||||
|
||||
save_config(config, config_path=mlx_path / "config.json")
|
||||
|
||||
if upload_repo is not None:
|
||||
upload_to_hub(mlx_path, upload_repo, base_hf_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = configure_parser()
|
||||
args = parser.parse_args()
|
||||
merge(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,62 +0,0 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads):
|
||||
self.n_kv_heads = n_kv_heads
|
||||
if isinstance(head_dim, int):
|
||||
self.k_head_dim = self.v_head_dim = head_dim
|
||||
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||
self.k_head_dim, self.v_head_dim = head_dim
|
||||
else:
|
||||
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
|
||||
v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelArgs:
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
@@ -1,201 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 8192
|
||||
num_hidden_layers: int = 40
|
||||
intermediate_size: int = 22528
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 64
|
||||
rope_theta: float = 8000000.0
|
||||
vocab_size: int = 256000
|
||||
layer_norm_eps: float = 1e-05
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
use_qk_norm: bool = False
|
||||
|
||||
|
||||
class LayerNorm2D(nn.Module):
|
||||
|
||||
def __init__(self, d1, d2, eps):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((d1, d2))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
attetion_bias = args.attention_bias
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
|
||||
|
||||
self.use_qk_norm = args.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
|
||||
self.k_norm = LayerNorm2D(
|
||||
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.n_heads, -1)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1)
|
||||
if self.use_qk_norm:
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class CohereModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = CohereModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = out * self.model.args.logit_scale
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,261 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
d_model: int
|
||||
ffn_config: dict
|
||||
attn_config: dict
|
||||
n_layers: int
|
||||
n_heads: int
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_heads = args.n_heads
|
||||
self.d_model = args.d_model
|
||||
self.head_dim = args.d_model // args.n_heads
|
||||
self.num_key_value_heads = args.attn_config["kv_n_heads"]
|
||||
self.clip_qkv = args.attn_config["clip_qkv"]
|
||||
self.rope_theta = args.attn_config["rope_theta"]
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.Wqkv = nn.Linear(
|
||||
args.d_model,
|
||||
(self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
|
||||
splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
|
||||
queries, keys, values = mx.split(qkv, splits, axis=-1)
|
||||
|
||||
B, L, D = x.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class NormAttnNorm(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
|
||||
self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
|
||||
self.attn = Attention(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
|
||||
x = h + x
|
||||
return x, self.norm_2(x)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, d_model: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||
self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
|
||||
self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
|
||||
self.act_fn = nn.silu
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
|
||||
current_hidden_states = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class Router(nn.Module):
|
||||
def __init__(self, d_model: int, num_experts: int):
|
||||
super().__init__()
|
||||
self.layer = nn.Linear(d_model, num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return self.layer(x)
|
||||
|
||||
|
||||
class SparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.d_model = args.d_model
|
||||
self.ffn_dim = args.ffn_config["ffn_hidden_size"]
|
||||
self.num_experts = args.ffn_config["moe_num_experts"]
|
||||
self.num_experts_per_tok = args.ffn_config["moe_top_k"]
|
||||
|
||||
self.router = Router(self.d_model, self.num_experts)
|
||||
self.experts = [
|
||||
MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
ne = self.num_experts_per_tok
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
gates = self.router(x)
|
||||
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
|
||||
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
|
||||
scores = scores.astype(x.dtype)
|
||||
|
||||
if self.training:
|
||||
inds = np.array(inds)
|
||||
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
|
||||
for e, expert in enumerate(self.experts):
|
||||
idx1, idx2 = map(mx.array, np.where(inds == e))
|
||||
if idx1.size == 0:
|
||||
continue
|
||||
y[idx1, idx2] = expert(x[idx1])
|
||||
|
||||
y = (y * scores[:, :, None]).sum(axis=1)
|
||||
else:
|
||||
y = []
|
||||
for xt, st, it in zip(x, scores, inds.tolist()):
|
||||
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
|
||||
yt = (yt * st).sum(axis=-1)
|
||||
y.append(yt)
|
||||
y = mx.stack(y, axis=0)
|
||||
|
||||
return y.reshape(orig_shape)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.ffn = SparseMoeBlock(args)
|
||||
self.norm_attn_norm = NormAttnNorm(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r, h = self.norm_attn_norm(x, mask, cache)
|
||||
out = self.ffn(h) + r
|
||||
return out
|
||||
|
||||
|
||||
class DBRX(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = args.vocab_size
|
||||
self.wte = nn.Embedding(args.vocab_size, args.d_model)
|
||||
self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
|
||||
self.norm_f = nn.LayerNorm(args.d_model, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.blocks)
|
||||
|
||||
for layer, c in zip(self.blocks, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm_f(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.transformer = DBRX(args)
|
||||
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.blocks
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Split experts into sub matrices
|
||||
num_experts = self.args.ffn_config["moe_num_experts"]
|
||||
dim = self.args.ffn_config["ffn_hidden_size"]
|
||||
|
||||
pattern = "experts.mlp"
|
||||
new_weights = {k: v for k, v in weights.items() if pattern not in k}
|
||||
for k, v in weights.items():
|
||||
if pattern in k:
|
||||
experts = [
|
||||
(k.replace(".mlp", f".{e}") + ".weight", sv)
|
||||
for e, sv in enumerate(mx.split(v, num_experts, axis=0))
|
||||
]
|
||||
if k.endswith("w2"):
|
||||
experts = [(s, sv.T) for s, sv in experts]
|
||||
new_weights.update(experts)
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.attn_config["kv_n_heads"]
|
||||
@@ -1,468 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "deepseek_v2"
|
||||
vocab_size: int = 102400
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
moe_intermediate_size: int = 1407
|
||||
num_hidden_layers: int = 30
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_routed_experts: Optional[int] = None
|
||||
routed_scaling_factor: float = 1.0
|
||||
kv_lora_rank: int = 512
|
||||
q_lora_rank: int = 1536
|
||||
qk_rope_head_dim: int = 64
|
||||
v_head_dim: int = 128
|
||||
qk_nope_head_dim: int = 128
|
||||
topk_method: str = "gready"
|
||||
n_group: Optional[int] = None
|
||||
topk_group: Optional[int] = None
|
||||
num_experts_per_tok: Optional[int] = None
|
||||
moe_layer_freq: int = 1
|
||||
first_k_dense_replace: int = 0
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Optional[Dict] = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
def yarn_find_correction_dim(
|
||||
num_rotations, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
|
||||
def yarn_find_correction_range(
|
||||
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
|
||||
):
|
||||
low = math.floor(
|
||||
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
high = math.ceil(
|
||||
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
|
||||
)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min)
|
||||
ramp_func = mx.clip(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
scaling_factor=1.0,
|
||||
original_max_position_embeddings=4096,
|
||||
beta_fast=32,
|
||||
beta_slow=1,
|
||||
mscale=1,
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.scaling_factor = scaling_factor
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = mscale
|
||||
self.mscale_all_dim = mscale_all_dim
|
||||
|
||||
self.max_seq_len_cached = None
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._inv_freq = None
|
||||
self.set_cos_sin_cache(max_position_embeddings)
|
||||
|
||||
def set_cos_sin_cache(self, seq_len):
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.dim
|
||||
freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
|
||||
freq_inter = 1.0 / (
|
||||
self.scaling_factor
|
||||
* self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
)
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.original_max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self._inv_freq = inv_freq
|
||||
|
||||
t = mx.arange(seq_len, dtype=mx.float32)
|
||||
freqs = mx.outer(t, inv_freq)
|
||||
|
||||
mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
|
||||
self.scaling_factor, self.mscale_all_dim
|
||||
)
|
||||
|
||||
self._cos_cached = mx.cos(freqs) * mscale
|
||||
self._sin_cached = mx.sin(freqs) * mscale
|
||||
|
||||
def apply_rotary_pos_emb(self, x, cos, sin):
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rx1 = x1 * cos - x2 * sin
|
||||
rx2 = x1 * sin + x2 * cos
|
||||
return mx.concatenate([rx1, rx2], axis=-1)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
seq_len = offset + x.shape[2]
|
||||
if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
||||
self.set_cos_sin_cache(seq_len=seq_len)
|
||||
|
||||
if self._cos_cached.dtype != x.dtype:
|
||||
self._cos_cached = self._cos_cached.astype(x.dtype)
|
||||
self._sin_cached = self._sin_cached.astype(x.dtype)
|
||||
|
||||
return self.apply_rotary_pos_emb(
|
||||
x,
|
||||
self._cos_cached[offset:seq_len],
|
||||
self._sin_cached[offset:seq_len],
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
|
||||
self.scale = self.q_head_dim**-0.5
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(
|
||||
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
|
||||
)
|
||||
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(
|
||||
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads
|
||||
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV2YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(x)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
|
||||
|
||||
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
|
||||
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(x)
|
||||
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
|
||||
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
def __init__(
|
||||
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
)
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
|
||||
|
||||
def __call__(self, x):
|
||||
gates = x @ self.weight.T
|
||||
|
||||
scores = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
if self.topk_method == "group_limited_greedy":
|
||||
bsz, seq_len = x.shape[:2]
|
||||
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
|
||||
group_scores = scores.max(axis=-1)
|
||||
k = self.n_group - self.topk_group
|
||||
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
|
||||
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
|
||||
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
|
||||
scores[batch_idx, seq_idx, group_idx] = 0.0
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
return inds, scores
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
inds, scores = self.gate(x)
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(x)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV2Attention(config)
|
||||
self.mlp = (
|
||||
DeepseekV2MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV2MLP(config)
|
||||
)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [
|
||||
DeepseekV2DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV2Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
|
||||
self.args.v_head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,184 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
h = h * (self.args.hidden_size**0.5)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = GemmaModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,211 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
attn_logit_softcapping: float = 50.0
|
||||
final_logit_softcapping: float = 30.0
|
||||
query_pre_attn_scalar: float = 144.0
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
|
||||
self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
self.attn_logit_softcapping = args.attn_logit_softcapping
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries * self.scale
|
||||
|
||||
if self.repeats > 1:
|
||||
queries = queries.reshape(
|
||||
B, self.n_kv_heads, self.repeats, L, self.head_dim
|
||||
)
|
||||
keys = mx.expand_dims(keys, 2)
|
||||
values = mx.expand_dims(values, 2)
|
||||
|
||||
scores = queries @ keys.swapaxes(-1, -2)
|
||||
scores = mx.tanh(scores / self.attn_logit_softcapping)
|
||||
scores *= self.attn_logit_softcapping
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, precise=True, axis=-1)
|
||||
output = scores @ values
|
||||
if self.repeats > 1:
|
||||
output = output.reshape(B, self.n_heads, L, self.head_dim)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_feedforward_layernorm = RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
|
||||
h = x + self.post_attention_layernorm(r)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
|
||||
mx.float32
|
||||
)
|
||||
out = h + self.post_feedforward_layernorm(r)
|
||||
return out
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
h = h * (self.args.hidden_size**0.5)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.final_logit_softcapping = args.final_logit_softcapping
|
||||
self.model = GemmaModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = mx.tanh(out / self.final_logit_softcapping)
|
||||
out = out * self.final_logit_softcapping
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,207 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_additive_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_ctx: int
|
||||
n_embd: int
|
||||
n_head: int
|
||||
n_layer: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
|
||||
|
||||
self.n_embd = args.n_embd
|
||||
self.n_head = args.n_head
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
|
||||
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.n_embd = args.n_embd
|
||||
self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
|
||||
self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu_approx(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.layer_norm_epsilon = args.layer_norm_epsilon
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(
|
||||
self.n_embd,
|
||||
eps=self.layer_norm_epsilon,
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_embd = args.n_embd
|
||||
self.n_positions = args.n_positions
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layer = args.n_layer
|
||||
self.layer_norm_epsilon = args.layer_norm_epsilon
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(self.vocab_size, self.n_embd)
|
||||
self.wpe = nn.Embedding(self.n_positions, self.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
_, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if hidden_states.shape[1] > 1:
|
||||
|
||||
position_ids = mx.array(np.arange(L))
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
mask = create_additive_causal_mask(
|
||||
hidden_states.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(hidden_states.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GPT2Model(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
out = self.model.wte.as_linear(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
for i in range(self.args.n_layer):
|
||||
if f"h.{i}.attn.bias" in weights:
|
||||
del weights[f"h.{i}.attn.bias"]
|
||||
if f"h.{i}.attn.c_attn.weight" in weights:
|
||||
weights[f"h.{i}.attn.c_attn.weight"] = weights[
|
||||
f"h.{i}.attn.c_attn.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.attn.c_proj.weight" in weights:
|
||||
weights[f"h.{i}.attn.c_proj.weight"] = weights[
|
||||
f"h.{i}.attn.c_proj.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.mlp.c_fc.weight" in weights:
|
||||
weights[f"h.{i}.mlp.c_fc.weight"] = weights[
|
||||
f"h.{i}.mlp.c_fc.weight"
|
||||
].transpose(1, 0)
|
||||
if f"h.{i}.mlp.c_proj.weight" in weights:
|
||||
weights[f"h.{i}.mlp.c_proj.weight"] = weights[
|
||||
f"h.{i}.mlp.c_proj.weight"
|
||||
].transpose(1, 0)
|
||||
for weight in weights:
|
||||
if not weight.startswith("model."):
|
||||
new_weights[f"model.{weight}"] = weights[weight]
|
||||
else:
|
||||
new_weights[weight] = weights[weight]
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,195 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_additive_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
n_embd: int
|
||||
n_layer: int
|
||||
n_inner: int
|
||||
n_head: int
|
||||
n_positions: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
multi_query: bool = True
|
||||
attention_bias: bool = True
|
||||
mlp_bias: bool = True
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = 1 if self.multi_query else self.n_head
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim = args.n_embd
|
||||
self.n_heads = n_heads = args.n_head
|
||||
self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
|
||||
|
||||
self.head_dim = head_dim = dim // n_heads
|
||||
|
||||
self.kv_dim = n_kv_heads * head_dim
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
|
||||
self.c_proj = nn.Linear(dim, dim, bias=attention_bias)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.c_attn(x)
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [self.dim, self.dim + self.kv_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.n_embd
|
||||
hidden_dim = args.n_inner
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_head = args.n_head
|
||||
self.n_embd = args.n_embd
|
||||
self.attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.ln_2(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
assert self.vocab_size > 0
|
||||
self.wte = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
self.wpe = nn.Embedding(args.n_positions, args.n_embd)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
|
||||
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
B, L = inputs.shape
|
||||
|
||||
hidden_states = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if hidden_states.shape[1] > 1:
|
||||
|
||||
position_ids = mx.array(np.arange(L))
|
||||
hidden_states += self.wpe(position_ids)
|
||||
|
||||
mask = create_additive_causal_mask(
|
||||
hidden_states.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(hidden_states.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
return self.ln_f(hidden_states)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = GPTBigCodeModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.transformer.wte.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,227 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs, create_additive_causal_mask
|
||||
|
||||
# Based on the transformers implementation at:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
max_position_embeddings: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
layer_norm_eps: float
|
||||
vocab_size: int
|
||||
rotary_emb_base: int
|
||||
rotary_pct: float
|
||||
num_key_value_heads: int = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
args.hidden_size % args.num_attention_heads == 0
|
||||
), "hidden_size must be divisible by num_attention_heads"
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
dims=int(self.head_dim * args.rotary_pct),
|
||||
traditional=False,
|
||||
base=args.rotary_emb_base,
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = nn.Linear(
|
||||
self.hidden_size, 3 * self.hidden_size, bias=True
|
||||
)
|
||||
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.query_key_value(x)
|
||||
|
||||
new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
|
||||
qkv = qkv.reshape(*new_qkv_shape)
|
||||
|
||||
queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
|
||||
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
# gelu_approx corresponds to FastGELUActivation in transformers.
|
||||
return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.layer_norm_eps = args.layer_norm_eps
|
||||
self.attention = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
self.hidden_size,
|
||||
eps=self.layer_norm_eps,
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
self.hidden_size, eps=self.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
# NeoX runs attention and feedforward network in parallel.
|
||||
attn = self.attention(self.input_layernorm(x), mask, cache)
|
||||
ffn = self.mlp(self.post_attention_layernorm(x))
|
||||
out = attn + ffn + residual
|
||||
return out
|
||||
|
||||
|
||||
class GPTNeoXModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
self.layer_norm_eps = args.layer_norm_eps
|
||||
assert self.vocab_size > 0
|
||||
self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
|
||||
self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
|
||||
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
_, L = inputs.shape
|
||||
|
||||
hidden_states = self.embed_in(inputs)
|
||||
|
||||
mask = None
|
||||
if hidden_states.shape[1] > 1:
|
||||
mask = create_additive_causal_mask(
|
||||
hidden_states.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(hidden_states.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
hidden_states = layer(hidden_states, mask, cache=c)
|
||||
|
||||
out = self.final_layer_norm(hidden_states)
|
||||
out = self.embed_out(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GPTNeoXModel(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
new_weights = {}
|
||||
|
||||
for w_key, w_value in weights.items():
|
||||
# Created through register_buffer in Pytorch, not needed here.
|
||||
ignore_suffixes = [
|
||||
".attention.bias",
|
||||
".attention.masked_bias",
|
||||
".attention.rotary_emb.inv_freq",
|
||||
]
|
||||
|
||||
skip_weight = False
|
||||
for ignored_suffix in ignore_suffixes:
|
||||
if w_key.endswith(ignored_suffix):
|
||||
skip_weight = True
|
||||
break
|
||||
|
||||
if skip_weight:
|
||||
continue
|
||||
|
||||
if not w_key.startswith("model."):
|
||||
w_key = f"model.{w_key}"
|
||||
|
||||
w_key = w_key.replace(".gpt_neox.layers.", ".h.")
|
||||
w_key = w_key.replace(".gpt_neox.", ".")
|
||||
|
||||
new_weights[w_key] = w_value
|
||||
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,247 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
bias: bool = True
|
||||
max_position_embeddings: int = 32768
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_base = base
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
self.scale = scale
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.original_base * (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
else:
|
||||
base = self.original_base
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_kv_groups = n_heads // args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
|
||||
)
|
||||
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 2.0
|
||||
)
|
||||
|
||||
self.rope = DynamicNTKScalingRoPE(
|
||||
head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv_states = self.wqkv(x)
|
||||
qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)
|
||||
|
||||
queries = qkv_states[..., : self.n_kv_groups, :]
|
||||
queries = queries.reshape(B, L, -1, self.head_dim)
|
||||
keys = qkv_states[..., -2, :]
|
||||
values = qkv_states[..., -1, :]
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.feed_forward(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class InternLM2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
assert args.vocab_size > 0
|
||||
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.tok_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = InternLM2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.tok_embeddings.as_linear(out)
|
||||
else:
|
||||
out = self.output(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,329 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_additive_causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
head_dim: Optional[int] = None
|
||||
max_position_embeddings: Optional[int] = None
|
||||
num_key_value_heads: Optional[int] = None
|
||||
attention_bias: bool = False
|
||||
mlp_bias: bool = False
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
if not "factor" in self.rope_scaling:
|
||||
raise ValueError(f"rope_scaling must contain 'factor'")
|
||||
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
|
||||
"rope_type"
|
||||
)
|
||||
if rope_type is None:
|
||||
raise ValueError(
|
||||
f"rope_scaling must contain either 'type' or 'rope_type'"
|
||||
)
|
||||
if rope_type not in ["linear", "dynamic", "llama3"]:
|
||||
raise ValueError(
|
||||
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
|
||||
)
|
||||
|
||||
|
||||
class DynamicNTKScalingRoPE(nn.Module):
|
||||
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
traditional: bool = False,
|
||||
base: float = 10000,
|
||||
scale: float = 1.0,
|
||||
rope_type: str = "default",
|
||||
rope_scaling: dict = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.traditional = traditional
|
||||
self.original_base = base
|
||||
self.scale = scale
|
||||
self.rope_type = rope_type
|
||||
self.rope_scaling = rope_scaling
|
||||
self.base = self.compute_base_freq()
|
||||
|
||||
def compute_base_freq(self):
|
||||
if self.rope_type == "llama3":
|
||||
return self.compute_llama3_base_freq()
|
||||
return self.original_base
|
||||
|
||||
# source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
|
||||
def compute_llama3_base_freq(self):
|
||||
factor = self.rope_scaling["factor"]
|
||||
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.rope_scaling.get(
|
||||
"original_max_position_embeddings",
|
||||
8192,
|
||||
)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||
wavelens = 2 * mx.pi * freqs
|
||||
new_base_freqs = []
|
||||
|
||||
smooths = (wavelens - high_freq_wavelen) / (
|
||||
low_freq_wavelen - high_freq_wavelen
|
||||
)
|
||||
new_base_freqs = freqs * (1 - smooths) * factor + smooths
|
||||
new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
|
||||
new_base_freqs = mx.where(
|
||||
wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
|
||||
)
|
||||
return new_base_freqs.mean().item()
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"{self.dims}, traditional={self.traditional}, "
|
||||
f"max_position_embeddings={self.max_position_embeddings}, "
|
||||
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
|
||||
)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
base = self.base
|
||||
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
|
||||
base *= (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
def initialize_rope(args: ModelArgs):
|
||||
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
|
||||
|
||||
rope_scaling = args.rope_scaling
|
||||
rope_type = "default"
|
||||
rope_scale = 1.0
|
||||
|
||||
if rope_scaling is not None:
|
||||
rope_type = (
|
||||
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
|
||||
)
|
||||
if rope_type == "linear":
|
||||
rope_scale = 1 / rope_scaling["factor"]
|
||||
elif rope_type == "llama3":
|
||||
rope_scale = 1.0 # The scaling is handled internally for llama3
|
||||
|
||||
return DynamicNTKScalingRoPE(
|
||||
dims=head_dim,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
rope_type=rope_type,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
if hasattr(args, "attention_bias"):
|
||||
attention_bias = args.attention_bias
|
||||
else:
|
||||
attention_bias = False
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||
|
||||
self.rope = initialize_rope(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.intermediate_size
|
||||
if hasattr(args, "mlp_bias"):
|
||||
mlp_bias = args.mlp_bias
|
||||
else:
|
||||
mlp_bias = False
|
||||
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = create_additive_causal_mask(
|
||||
h.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,216 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
dim_model_base: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int
|
||||
scale_depth: float
|
||||
scale_emb: float
|
||||
rope_theta: float = 1000000.0
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[str, float]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_heads = n_heads = args.num_attention_heads
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
dims=self.head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=self.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
attn_output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.o_proj(attn_output)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
self.scale_depth = args.scale_depth
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||
return out
|
||||
|
||||
|
||||
class MiniCPMModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
assert self.vocab_size > 0
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs) * self.args.scale_emb
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = MiniCPMModel(args)
|
||||
|
||||
if not self.args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
|
||||
if not self.args.tie_word_embeddings:
|
||||
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
|
||||
else:
|
||||
out = out @ self.model.embed_tokens.weight.T
|
||||
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "lm_head.weight" not in weights:
|
||||
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,227 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int = 32000
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 14336
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_experts_per_tok: int = 2
|
||||
num_key_value_heads: int = 8
|
||||
num_local_experts: int = 8
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 1e6
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.num_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MixtralSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_dim = args.hidden_size
|
||||
self.ffn_dim = args.intermediate_size
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
|
||||
# gating
|
||||
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
||||
|
||||
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
|
||||
self.self_attn = MixtralAttention(args)
|
||||
|
||||
self.block_sparse_moe = MixtralSparseMoeBlock(args)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.block_sparse_moe(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class MixtralModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = MixtralModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(
|
||||
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
|
||||
)
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
|
||||
mx.stack(to_join)
|
||||
)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,185 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from sys import exit
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
d_model: int
|
||||
n_layers: int
|
||||
mlp_hidden_size: int
|
||||
n_heads: int
|
||||
vocab_size: int
|
||||
embedding_size: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
mlp_ratio: int = 4
|
||||
weight_tying: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.mlp_hidden_size = (
|
||||
self.mlp_hidden_size
|
||||
if self.mlp_hidden_size is not None
|
||||
else self.mlp_ratio * self.d_model
|
||||
)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
dim = args.d_model
|
||||
|
||||
self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
|
||||
self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
|
||||
|
||||
self.att_norm = nn.LayerNorm(dim, affine=False)
|
||||
self.ff_norm = nn.LayerNorm(dim, affine=False)
|
||||
|
||||
head_dim = dim // self.n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
|
||||
self.attn_out = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
self.args = args
|
||||
|
||||
def attend(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.attn_out(output)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.attend(self.att_norm(x), mask, cache)
|
||||
h = x + r
|
||||
|
||||
x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
|
||||
|
||||
out = h + self.ff_out(nn.silu(x2) * x1)
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_layers = args.n_layers
|
||||
self.weight_tying = args.weight_tying
|
||||
|
||||
self.wte = nn.Embedding(args.embedding_size, args.d_model)
|
||||
self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
|
||||
if not self.weight_tying:
|
||||
self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
|
||||
self.norm = nn.LayerNorm(args.d_model, affine=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.blocks)
|
||||
|
||||
for block, c in zip(self.blocks, cache):
|
||||
h = block(h, mask, c)
|
||||
|
||||
h = self.norm(h)
|
||||
|
||||
if self.weight_tying:
|
||||
return self.wte.as_linear(h), cache
|
||||
|
||||
return self.ff_out(h)
|
||||
|
||||
|
||||
class OlmoModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.transformer = Transformer(args)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
return self.transformer(inputs, cache)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = OlmoModel(args)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
return self.model(inputs, cache)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.transformer.blocks
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.n_heads
|
||||
@@ -1,229 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
head_dim: int
|
||||
num_transformer_layers: int
|
||||
model_dim: int
|
||||
vocab_size: int
|
||||
ffn_dim_divisor: int
|
||||
num_query_heads: List
|
||||
num_kv_heads: List
|
||||
ffn_multipliers: List
|
||||
ffn_with_glu: bool = True
|
||||
normalize_qk_projections: bool = True
|
||||
share_input_output_layers: bool = True
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_freq_constant: float = 10000
|
||||
|
||||
|
||||
def make_divisible(
|
||||
v: Union[float, int],
|
||||
divisor: Optional[int] = 8,
|
||||
min_value: Optional[Union[float, int]] = None,
|
||||
) -> Union[float, int]:
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by the divisor
|
||||
It can be seen at:
|
||||
https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
|
||||
Args:
|
||||
v: input value
|
||||
divisor: default to 8
|
||||
min_value: minimum divisor value
|
||||
Returns:
|
||||
new_v: new divisible value
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_id: int):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim = args.head_dim
|
||||
self.layer_id = layer_id
|
||||
self.model_dim = model_dim = args.model_dim
|
||||
|
||||
self.n_heads = n_heads = args.num_query_heads[layer_id]
|
||||
self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id]
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
op_size = (n_heads + (n_kv_heads * 2)) * head_dim
|
||||
self.qkv_proj = nn.Linear(model_dim, op_size, bias=False)
|
||||
self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)
|
||||
|
||||
self.normalize_qk_projections = args.normalize_qk_projections
|
||||
|
||||
if self.normalize_qk_projections:
|
||||
self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
|
||||
self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.qkv_proj(x)
|
||||
|
||||
qkv = qkv.reshape(
|
||||
B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
if self.normalize_qk_projections:
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_id: int):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
dim = args.model_dim
|
||||
ffn_multiplier = args.ffn_multipliers[layer_id]
|
||||
|
||||
intermediate_dim = int(
|
||||
make_divisible(
|
||||
ffn_multiplier * args.model_dim,
|
||||
divisor=args.ffn_dim_divisor,
|
||||
)
|
||||
)
|
||||
|
||||
self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False)
|
||||
self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.proj_1(x)
|
||||
gate, x = mx.split(x, 2, axis=-1)
|
||||
return self.proj_2(nn.silu(gate) * x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_id: int):
|
||||
super().__init__()
|
||||
dim = args.model_dim
|
||||
self.attn = Attention(args, layer_id=layer_id)
|
||||
self.ffn = MLP(args, layer_id=layer_id)
|
||||
self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
|
||||
self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.attn_norm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.ffn(self.ffn_norm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class OpenELMModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_transformer_layers = args.num_transformer_layers
|
||||
assert self.vocab_size > 0
|
||||
self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
|
||||
self.layers = [
|
||||
TransformerBlock(args, layer_id=layer_id)
|
||||
for layer_id in range(self.num_transformer_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.token_embeddings(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.transformer = OpenELMModel(args)
|
||||
if not args.share_input_output_layers:
|
||||
self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.transformer(inputs, cache)
|
||||
if self.args.share_input_output_layers:
|
||||
out = self.transformer.token_embeddings.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_kv_heads
|
||||
@@ -1,180 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str = "phi"
|
||||
max_position_embeddings: int = 2048
|
||||
vocab_size: int = 51200
|
||||
hidden_size: int = 2560
|
||||
num_attention_heads: int = 32
|
||||
num_hidden_layers: int = 32
|
||||
num_key_value_heads: int = 32
|
||||
partial_rotary_factor: float = 0.4
|
||||
intermediate_size: int = 10240
|
||||
layer_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class PhiAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.repeats = self.num_heads // self.num_key_value_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
|
||||
)
|
||||
self.dense = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=True
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
B, L, D = queries.shape
|
||||
n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(
|
||||
B,
|
||||
L,
|
||||
n_heads,
|
||||
-1,
|
||||
).moveaxis(1, 2)
|
||||
keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
|
||||
values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
||||
).astype(values.dtype)
|
||||
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class PhiDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = PhiAttention(config=config)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = PhiMLP(config)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
ff_h = self.mlp(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class PhiModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, cache):
|
||||
x = self.embed_tokens(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, c)
|
||||
return self.final_layernorm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = PhiModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
y = self.model(x, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,215 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
|
||||
max_position_embeddings: int = 131072
|
||||
original_max_position_embeddings: int = 4096
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"long_factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
|
||||
print(
|
||||
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
|
||||
)
|
||||
self.rope_scaling = None
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
|
||||
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = 1.0
|
||||
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
||||
self.rope = SuScaledRotaryEmbedding(
|
||||
head_dim,
|
||||
traditional=False,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
max_position_embeddings=args.max_position_embeddings,
|
||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||
short_factor=args.rope_scaling["short_factor"],
|
||||
long_factor=args.rope_scaling["long_factor"],
|
||||
)
|
||||
else:
|
||||
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
||||
assert isinstance(args.rope_scaling["factor"], float)
|
||||
rope_scale = 1 / args.rope_scaling["factor"]
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.qkv_proj(x)
|
||||
query_pos = self.n_heads * self.head_dim
|
||||
queries, keys, values = mx.split(
|
||||
qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.gate_up_proj(x)
|
||||
gate, x = mx.split(x, 2, axis=-1)
|
||||
return self.down_proj(nn.silu(gate) * x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,320 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
dense_attention_every_n_layers: int
|
||||
ff_intermediate_size: int
|
||||
gegelu_limit: float
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
mup_attn_multiplier: float = 1.0
|
||||
mup_use_scaling: bool = True
|
||||
mup_embedding_multiplier: float = 10.0
|
||||
mup_width_multiplier: float = 8.0
|
||||
rope_embedding_base: float = 1000000
|
||||
rope_position_scale: float = 1.0
|
||||
blocksparse_block_size: Tuple[int] = (64,)
|
||||
blocksparse_num_local_blocks: int = 16
|
||||
blocksparse_vert_stride: int = 8
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def gegelu_impl(a_gelu, a_linear, limit):
|
||||
a_gelu = mx.where(
|
||||
mx.isinf(a_gelu),
|
||||
a_gelu,
|
||||
mx.clip(a_gelu, a_min=None, a_max=limit),
|
||||
)
|
||||
a_linear = mx.where(
|
||||
mx.isinf(a_linear),
|
||||
a_linear,
|
||||
mx.clip(a_linear, a_min=-limit, a_max=limit),
|
||||
)
|
||||
out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
|
||||
return out_gelu * (a_linear + 1.0)
|
||||
|
||||
|
||||
def gegelu(x, limit):
|
||||
a_gelu, a_linear = x[..., ::2], x[..., 1::2]
|
||||
return gegelu_impl(a_gelu, a_linear, limit)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_q_per_kv = n_heads // n_kv_heads
|
||||
|
||||
self.head_dim = head_dim = args.hidden_size // n_heads
|
||||
|
||||
self.query_key_value = nn.Linear(
|
||||
dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
|
||||
)
|
||||
self.dense = nn.Linear(dim, dim)
|
||||
|
||||
if args.mup_use_scaling:
|
||||
norm_factor = head_dim / args.mup_attn_multiplier
|
||||
else:
|
||||
norm_factor = math.sqrt(head_dim)
|
||||
self.scale = 1.0 / norm_factor
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=False,
|
||||
base=args.rope_embedding_base,
|
||||
scale=args.rope_position_scale,
|
||||
)
|
||||
|
||||
if layer_idx % args.dense_attention_every_n_layers == 0:
|
||||
self.block_sparse = True
|
||||
self.blocksparse_block_size = args.blocksparse_block_size
|
||||
if self.blocksparse_block_size not in (32, 64):
|
||||
raise ValueError(
|
||||
f"Unsupported block size {self.blocksparse_block_size}"
|
||||
)
|
||||
self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
|
||||
self.blocksparse_vert_stride = args.blocksparse_vert_stride
|
||||
else:
|
||||
self.block_sparse = False
|
||||
|
||||
def _block_sparse_mask(self, q_len, kv_len):
|
||||
vert_stride = self.blocksparse_vert_stride
|
||||
local_blocks = self.blocksparse_num_local_blocks
|
||||
block_size = self.blocksparse_block_size
|
||||
n_heads = self.n_heads
|
||||
|
||||
kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
q_blocks = (q_len + block_size - 1) // block_size
|
||||
q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
|
||||
k_pos = mx.arange(kv_blocks)[None, None]
|
||||
|
||||
mask_vert_strided = (
|
||||
mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
|
||||
) % vert_stride
|
||||
mask_vert_strided = (mask_vert_strided == 0)[:, None, :]
|
||||
|
||||
block_mask = (q_pos >= k_pos) & (
|
||||
(q_pos - k_pos < local_blocks) | mask_vert_strided
|
||||
)
|
||||
block_mask = block_mask.reshape(
|
||||
self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
|
||||
)
|
||||
dense_mask = mx.repeat(
|
||||
mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
|
||||
)
|
||||
return block_mask, dense_mask[..., -q_len:, :kv_len]
|
||||
|
||||
def _block_sparse_attention(self, queries, keys, values, scale, mask):
|
||||
queries = scale * queries
|
||||
B = queries.shape[0]
|
||||
L = queries.shape[2]
|
||||
queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
|
||||
keys = mx.expand_dims(keys, 2)
|
||||
values = mx.expand_dims(values, 2)
|
||||
|
||||
# TODO get rid of dense mask if we have a fill value
|
||||
block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
|
||||
scores = queries @ mx.swapaxes(keys, -1, -2)
|
||||
# TODO, uncomment when faster
|
||||
# scores = mx.block_masked_mm(
|
||||
# queries,
|
||||
# mx.swapaxes(keys, -1, -2),
|
||||
# mask_out=block_mask,
|
||||
# block_size=self.blocksparse_block_size,
|
||||
# )
|
||||
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = scores + mx.where(
|
||||
dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
|
||||
)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
output = scores @ values
|
||||
# TODO, uncomment when faster
|
||||
# output = mx.block_masked_mm(
|
||||
# scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
|
||||
# )
|
||||
return mx.reshape(output, (B, self.n_heads, L, -1))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
qkv = self.query_key_value(x)
|
||||
qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
|
||||
queries = qkv[..., :-2, :].flatten(-3, -2)
|
||||
keys = qkv[..., -2, :]
|
||||
values = qkv[..., -1, :]
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
if self.block_sparse:
|
||||
output = self._block_sparse_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.dense(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
hidden_dim = args.ff_intermediate_size
|
||||
self.gegelu_limit = args.gegelu_limit
|
||||
self.up_proj = nn.Linear(dim, 2 * hidden_dim)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
x = self.up_proj(x)
|
||||
return self.down_proj(gegelu(x, self.gegelu_limit))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs, layer_idx):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args, layer_idx)
|
||||
self.mlp = MLP(args)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_epsilon
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size,
|
||||
eps=args.layer_norm_epsilon,
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.mup_embedding_multiplier = args.mup_embedding_multiplier
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args, layer_idx=l)
|
||||
for l in range(args.num_hidden_layers)
|
||||
]
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.layer_norm_epsilon
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
if self.mup_embedding_multiplier:
|
||||
h = self.mup_embedding_multiplier * h
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.final_layernorm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
self.args = args
|
||||
self.mup_width_multiplier = args.mup_width_multiplier
|
||||
self._dummy_tokenizer_ids = mx.array(
|
||||
[100256, 100258, 100259, 100260, 100264, 100265]
|
||||
+ list(range(100267, 100352))
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
if self.mup_width_multiplier:
|
||||
out = out / self.mup_width_multiplier
|
||||
out[self._dummy_tokenizer_ids] = -float("inf")
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,203 +0,0 @@
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .switch_layers import SwitchMLP
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
model_type: str
|
||||
num_vocab: int = 51200
|
||||
model_dim: int = 2560
|
||||
num_heads: int = 32
|
||||
num_layers: int = 32
|
||||
rotary_dim: int = 32
|
||||
num_experts_per_tok: int = 2
|
||||
num_local_experts: int = 4
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.rope = nn.RoPE(rotary_dim, traditional=False)
|
||||
self.Wqkv = nn.Linear(dims, 3 * dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.Wqkv(x)
|
||||
queries, keys, values = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
# Extract some shapes
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
||||
).astype(values.dtype)
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class MOE(nn.Module):
|
||||
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_experts = args.num_local_experts
|
||||
self.num_experts_per_tok = args.num_experts_per_tok
|
||||
self.switch_mlp = SwitchMLP(
|
||||
self.dim, self.hidden_dim, self.num_experts, bias=True
|
||||
)
|
||||
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gates = self.gate(x)
|
||||
|
||||
k = self.num_experts_per_tok
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
dims = config.model_dim
|
||||
mlp_dims = dims * 4
|
||||
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
|
||||
self.ln = nn.LayerNorm(dims)
|
||||
self.moe = MOE(config, dims, mlp_dims)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.ln(x)
|
||||
attn_h = self.mixer(h, mask, cache)
|
||||
ff_h = self.moe(h)
|
||||
return attn_h + ff_h + x
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embd = Embd(config)
|
||||
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embd(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
x = layer(x, mask, c)
|
||||
return x
|
||||
|
||||
|
||||
class Embd(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.wte(x)
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(config.model_dim)
|
||||
self.linear = nn.Linear(config.model_dim, config.num_vocab)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(self.ln(inputs))
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = TransformerDecoder(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_layers):
|
||||
prefix = f"transformer.h.{l}"
|
||||
for n in ["fc1", "fc2"]:
|
||||
for k in ["weight", "scales", "biases", "bias"]:
|
||||
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_local_experts)
|
||||
]
|
||||
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.model_dim // self.args.num_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_heads
|
||||
@@ -1,216 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
n_shared_head: int = 8
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
head_dim = self.hidden_size // config.num_attention_heads
|
||||
|
||||
self.q_num_heads = config.num_attention_heads
|
||||
self.qk_dim = self.v_dim = head_dim
|
||||
self.k_num_heads = self.v_num_heads = int(
|
||||
np.ceil(self.q_num_heads / config.n_shared_head)
|
||||
)
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.q_num_heads * self.qk_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.k_num_heads * self.qk_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.v_num_heads * self.v_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
|
||||
)
|
||||
self.rotary_emb = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
base=config.rope_theta,
|
||||
scale=1.0,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
bsz, q_len, _ = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rotary_emb(queries, offset=cache.offset)
|
||||
keys = self.rotary_emb(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rotary_emb(queries)
|
||||
keys = self.rotary_emb(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
scale=self.scale,
|
||||
mask=attention_mask,
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
|
||||
|
||||
|
||||
class PlamoDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Attention(config)
|
||||
self.mlp = MLP(config)
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[Any, ...]:
|
||||
# from LlamaDecoder
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states_sa = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_mlp = self.mlp(hidden_states)
|
||||
|
||||
hidden_states = residual + hidden_states_sa + hidden_states_mlp
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PlamoDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
class PlamoModel(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = PlamoDecoder(config) # type: ignore
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
|
||||
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(self.embed_tokens.weight.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None for _ in range(len(self.layers.layers))]
|
||||
|
||||
for layer, c in zip(self.layers.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.model = PlamoModel(args)
|
||||
self.lm_head: nn.Module = nn.Linear(
|
||||
args.hidden_size, args.vocab_size, bias=False
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_attention_heads // self.args.n_shared_head
|
||||
@@ -1,169 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int = 2048
|
||||
num_attention_heads: int = 16
|
||||
num_hidden_layers: int = 24
|
||||
kv_channels: int = 128
|
||||
max_position_embeddings: int = 8192
|
||||
layer_norm_epsilon: float = 1e-6
|
||||
intermediate_size: int = 11008
|
||||
no_bias: bool = True
|
||||
vocab_size: int = 151936
|
||||
num_key_value_heads = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
hidden_size = args.hidden_size
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
|
||||
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
|
||||
|
||||
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
|
||||
|
||||
proj_size = args.kv_channels * self.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
|
||||
|
||||
self.scale = hidden_size_per_attention_head**-0.5
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
qkv = self.c_attn(x)
|
||||
|
||||
q, k, v = mx.split(qkv, 3, axis=-1)
|
||||
|
||||
B, L, _ = q.shape
|
||||
|
||||
queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rotary_emb(queries, offset=cache.offset)
|
||||
keys = self.rotary_emb(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rotary_emb(queries)
|
||||
keys = self.rotary_emb(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.c_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
|
||||
)
|
||||
self.c_proj = nn.Linear(
|
||||
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
a1 = self.w1(x)
|
||||
a2 = self.w2(x)
|
||||
return self.c_proj(a1 * nn.silu(a2))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.attn = Attention(args)
|
||||
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
self.mlp = MLP(args)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
residual = x
|
||||
x = self.ln_1(x)
|
||||
x = self.attn(x, mask=mask, cache=cache)
|
||||
residual = x + residual
|
||||
x = self.ln_2(residual)
|
||||
x = self.mlp(x)
|
||||
x = x + residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class QwenModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
|
||||
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
|
||||
|
||||
def __call__(self, inputs, mask=None, cache=None):
|
||||
x = self.wte(inputs)
|
||||
|
||||
mask = None
|
||||
T = x.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.h)
|
||||
|
||||
for layer, c in zip(self.h, cache):
|
||||
x = layer(x, mask, c)
|
||||
|
||||
return self.ln_f(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.transformer = QwenModel(config)
|
||||
self.lm_head = nn.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=not config.no_bias
|
||||
)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_attention_heads
|
||||
@@ -1,207 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 1000000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling["factor"]
|
||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_attention_heads = args.num_attention_heads
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
if self.args.tie_word_embeddings:
|
||||
weights.pop("lm_head.weight", None)
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,247 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_experts_per_tok: int
|
||||
num_experts: int
|
||||
moe_intermediate_size: int
|
||||
shared_expert_intermediate_size: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
rope_theta: float = 1000000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
dim = args.hidden_size
|
||||
intermediate_size = args.moe_intermediate_size
|
||||
shared_expert_intermediate_size = args.shared_expert_intermediate_size
|
||||
|
||||
self.num_experts = num_experts = args.num_experts
|
||||
self.top_k = args.num_experts_per_tok
|
||||
|
||||
self.gate = nn.Linear(dim, num_experts, bias=False)
|
||||
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
|
||||
|
||||
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
|
||||
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
):
|
||||
gates = self.gate(x)
|
||||
gates = mx.softmax(gates, axis=-1, precise=True)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
|
||||
scores = mx.take_along_axis(gates, inds, axis=-1)
|
||||
|
||||
y = self.switch_mlp(x, inds)
|
||||
y = (y * scores[..., None]).sum(axis=-2)
|
||||
|
||||
shared_expert_output = self.shared_expert(x)
|
||||
shared_expert_output = (
|
||||
mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
|
||||
)
|
||||
|
||||
return y + shared_expert_output
|
||||
|
||||
|
||||
class Qwen2MoeDecoderLayer(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(args)
|
||||
|
||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
args.hidden_size, eps=args.rms_norm_eps
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Qwen2MoeModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2MoeModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
def sanitize(self, weights):
|
||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||
return weights
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n in ["up_proj", "down_proj", "gate_proj"]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
|
||||
to_join = [
|
||||
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
|
||||
for e in range(self.args.num_experts)
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,518 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
attention_bias: bool
|
||||
conv1d_width: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
logits_soft_cap: float
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
attention_window_size: int
|
||||
vocab_size: int
|
||||
embeddings_scale_by_sqrt_dim: bool = True
|
||||
block_types: Optional[List[str]] = None
|
||||
_block_types: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# For some reason these have different names in 2B and 9B
|
||||
if self.block_types is None:
|
||||
self.block_types = self._block_types
|
||||
|
||||
|
||||
def create_window_causal_mask(N: int, window_size: int):
|
||||
inds = mx.arange(N)
|
||||
linds = inds[:, None]
|
||||
rinds = inds[None]
|
||||
mask = (linds < rinds) | (linds > rinds + window_size)
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class RecurrentCache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache = (None, None)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._cache[idx]
|
||||
|
||||
def update(self, conv_state, recurrent_state):
|
||||
self._cache = (conv_state, recurrent_state)
|
||||
|
||||
|
||||
class WindowKVCache:
|
||||
|
||||
def __init__(self, window_size):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.window_size = window_size
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
# TODO consider using rotating buffer here
|
||||
# especially for very long generations
|
||||
def _update(x, v):
|
||||
t = x.shape[2] - self.window_size
|
||||
if t > 0:
|
||||
x = x[..., t:, :]
|
||||
return mx.concatenate([x, v], axis=2)
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
self.keys = _update(self.keys, keys)
|
||||
self.values = _update(self.values, values)
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.weight = mx.ones((dims,))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
|
||||
|
||||
def rnn_scan(x, a, h0):
|
||||
assert x.ndim == 3
|
||||
assert a.shape == x.shape[-a.ndim :]
|
||||
assert a.dtype == x.dtype
|
||||
|
||||
if x.shape[1] == 1:
|
||||
# Using scan in sampling mode.
|
||||
if h0 is None:
|
||||
return x, x[:, 0]
|
||||
|
||||
else:
|
||||
y = a * h0[:, None] + x
|
||||
return y, y[:, -1]
|
||||
|
||||
else:
|
||||
# Using scan in linear mode.
|
||||
if h0 is not None:
|
||||
h_t = h0
|
||||
else:
|
||||
B, _, D = x.shape
|
||||
h_t = mx.zeros((B, D), dtype=x.dtype)
|
||||
|
||||
y = mx.zeros_like(x)
|
||||
for t in range(x.shape[1]):
|
||||
h_t = a[:, t] * h_t + x[:, t]
|
||||
y[:, t] = h_t
|
||||
|
||||
return y, h_t
|
||||
|
||||
|
||||
class Conv1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((kernel_size, channels))
|
||||
self.bias = mx.zeros((channels,))
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
w = self.weight.T[..., None]
|
||||
kw, groups = self.weight.shape
|
||||
if cache is not None:
|
||||
l = []
|
||||
# Pad the cache if needed
|
||||
if cache.shape[1] < kw - 1:
|
||||
l.append(
|
||||
mx.zeros(
|
||||
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
|
||||
)
|
||||
)
|
||||
l.extend([cache, x])
|
||||
x = mx.concatenate(l, axis=1)
|
||||
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
|
||||
else:
|
||||
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
|
||||
|
||||
# The cache is always kw - 1
|
||||
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
|
||||
y = y + self.bias
|
||||
return y, cache
|
||||
|
||||
|
||||
class RGLRU(nn.Module):
|
||||
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.width // self.num_heads
|
||||
|
||||
self.recurrent_param = mx.zeros((self.width,))
|
||||
|
||||
self.input_gate_weight = mx.zeros(
|
||||
(self.num_heads, self.head_dim, self.head_dim),
|
||||
)
|
||||
self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
|
||||
|
||||
self.recurrent_gate_weight = mx.zeros(
|
||||
(self.num_heads, self.head_dim, self.head_dim),
|
||||
)
|
||||
self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
def apply_block_linear(h, w, b):
|
||||
h = h.reshape((B, L, self.num_heads, self.head_dim))
|
||||
h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
|
||||
return mx.sigmoid(h.flatten(2, 3))
|
||||
|
||||
# Gates for x and a.
|
||||
gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
|
||||
gate_a = apply_block_linear(
|
||||
x, self.recurrent_gate_weight, self.recurrent_gate_bias
|
||||
)
|
||||
|
||||
# Compute the parameter `A` of the recurrence.
|
||||
log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
|
||||
a = mx.exp(log_a)
|
||||
a_square = mx.exp(2 * log_a)
|
||||
|
||||
# Gate the input.
|
||||
gated_x = x * gate_x
|
||||
|
||||
# Apply gamma normalization to the input.
|
||||
multiplier = mx.sqrt(1 - a_square)
|
||||
if cache is None:
|
||||
multiplier[:, 0, :] = 1.0
|
||||
normalized_x = gated_x * multiplier.astype(x.dtype)
|
||||
|
||||
y, last_h = rnn_scan(
|
||||
x=normalized_x,
|
||||
a=a,
|
||||
h0=cache,
|
||||
)
|
||||
|
||||
return y, last_h
|
||||
|
||||
|
||||
class RecurrentBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
lru_width: int = None,
|
||||
conv1d_temporal_width: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.lru_width = lru_width or width
|
||||
self.conv1d_temporal_width = conv1d_temporal_width
|
||||
|
||||
self.linear_y = nn.Linear(width, self.lru_width)
|
||||
self.linear_x = nn.Linear(width, self.lru_width)
|
||||
self.linear_out = nn.Linear(self.lru_width, width)
|
||||
self.conv_1d = Conv1d(
|
||||
channels=self.lru_width,
|
||||
kernel_size=self.conv1d_temporal_width,
|
||||
)
|
||||
self.rg_lru = RGLRU(
|
||||
width=self.lru_width,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
# y branch.
|
||||
y = self.linear_y(x)
|
||||
y = nn.gelu_approx(y)
|
||||
|
||||
# x branch.
|
||||
x = self.linear_x(x)
|
||||
if cache is None:
|
||||
conv_state, recurrent_state = (None, None)
|
||||
else:
|
||||
conv_state, recurrent_state = cache[0], cache[1]
|
||||
x, conv_state = self.conv_1d(
|
||||
x=x,
|
||||
cache=conv_state,
|
||||
)
|
||||
x, recurrent_state = self.rg_lru(
|
||||
x=x,
|
||||
cache=recurrent_state,
|
||||
)
|
||||
if cache is not None:
|
||||
cache.update(conv_state, recurrent_state)
|
||||
|
||||
x = x * y
|
||||
x = self.linear_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LocalAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = (width // num_heads) ** (-0.5)
|
||||
|
||||
self.head_dim = self.width // self.num_heads
|
||||
self.q_proj = nn.Linear(self.width, self.width, bias=False)
|
||||
self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.width, self.width, bias=True)
|
||||
self.rope = nn.RoPE(
|
||||
self.head_dim // 2,
|
||||
traditional=False,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
|
||||
def __init__(self, width: int, expanded_width: int):
|
||||
super().__init__()
|
||||
self.up_proj = nn.Linear(width, expanded_width // 2)
|
||||
self.gate_proj = nn.Linear(width, expanded_width // 2)
|
||||
self.down_proj = nn.Linear(expanded_width // 2, width)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
gate = self.gate_proj(x)
|
||||
x = self.up_proj(x)
|
||||
return self.down_proj(nn.gelu_approx(gate) * x)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
mlp_expanded_width: int,
|
||||
num_heads: int,
|
||||
attention_window_size: int,
|
||||
temporal_block_type: str,
|
||||
lru_width: Optional[int] = None,
|
||||
conv1d_temporal_width: int = 4,
|
||||
):
|
||||
"""Initializes the residual block.
|
||||
|
||||
Args:
|
||||
width: The width of the block.
|
||||
mlp_expanded_width: The width of the expansion inside the MLP block.
|
||||
num_heads: The number of heads for the Attention or the RG-LRU.
|
||||
attention_window_size: The window size for the local attention block.
|
||||
temporal_block_type: Either "recurrent" or "attention", specifying the
|
||||
type of recurrent block to use.
|
||||
lru_width: The width of the RG-LRU if different from `width`.
|
||||
conv1d_temporal_width: The width of the temporal convolution.
|
||||
"""
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.mlp_expanded_width = mlp_expanded_width
|
||||
self.num_heads = num_heads
|
||||
self.attention_window_size = attention_window_size
|
||||
self.temporal_block_type = temporal_block_type
|
||||
self.lru_width = lru_width
|
||||
self.conv1d_temporal_width = conv1d_temporal_width
|
||||
|
||||
self.temporal_pre_norm = RMSNorm(width)
|
||||
if self.temporal_block_type == "recurrent":
|
||||
self.temporal_block = RecurrentBlock(
|
||||
width=self.width,
|
||||
num_heads=self.num_heads,
|
||||
lru_width=self.lru_width,
|
||||
conv1d_temporal_width=self.conv1d_temporal_width,
|
||||
)
|
||||
|
||||
else:
|
||||
self.temporal_block = LocalAttentionBlock(
|
||||
width=self.width,
|
||||
num_heads=self.num_heads,
|
||||
window_size=self.attention_window_size,
|
||||
)
|
||||
|
||||
self.channel_pre_norm = RMSNorm(width)
|
||||
self.mlp_block = MLPBlock(
|
||||
width=self.width,
|
||||
expanded_width=self.mlp_expanded_width,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache=None,
|
||||
mask=None,
|
||||
):
|
||||
raw_x = x
|
||||
|
||||
inputs_normalized = self.temporal_pre_norm(raw_x)
|
||||
|
||||
x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
|
||||
residual = x + raw_x
|
||||
|
||||
x = self.channel_pre_norm(residual)
|
||||
x = self.mlp_block(x)
|
||||
|
||||
x = x + residual
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Griffin(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
|
||||
block_types = config.block_types
|
||||
|
||||
self.layers = [
|
||||
ResidualBlock(
|
||||
width=config.hidden_size,
|
||||
mlp_expanded_width=config.intermediate_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
attention_window_size=config.attention_window_size,
|
||||
temporal_block_type=block_types[i % len(block_types)],
|
||||
lru_width=None,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tokens,
|
||||
cache=None,
|
||||
):
|
||||
x = self.embed_tokens(tokens)
|
||||
if self.scale_by_sqrt_dim:
|
||||
x = x * math.sqrt(x.shape[-1])
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = create_window_causal_mask(
|
||||
x.shape[1], self.config.attention_window_size
|
||||
)
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, mask=mask, cache=cache[i])
|
||||
|
||||
return self.final_norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
self.args = config
|
||||
self.model = Griffin(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
tokens: Sequence of input tokens.
|
||||
"""
|
||||
logits = self.model(tokens, cache=cache)
|
||||
if "lm_head" in self:
|
||||
logits = self.lm_head(logits)
|
||||
else:
|
||||
logits = self.model.embed_tokens.as_linear(logits)
|
||||
|
||||
c = self.args.logits_soft_cap
|
||||
if c:
|
||||
logits = mx.tanh(logits / c) * c
|
||||
return logits
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
for k, v in weights.items():
|
||||
if "conv_1d.weight" in k and v.ndim == 3:
|
||||
weights[k] = v.squeeze(1).T
|
||||
if "lm_head.weight" not in weights:
|
||||
self.pop("lm_head")
|
||||
return weights
|
||||
|
||||
def make_cache(self):
|
||||
cache = []
|
||||
for layer in self.layers:
|
||||
if layer.temporal_block_type == "recurrent":
|
||||
cache.append(RecurrentCache())
|
||||
else:
|
||||
cache.append(WindowKVCache(self.args.attention_window_size))
|
||||
return cache
|
||||
@@ -1,219 +0,0 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
intermediate_size: int
|
||||
rope_theta: float
|
||||
use_qkv_bias: bool
|
||||
partial_rotary_factor: float
|
||||
layer_norm_eps: float
|
||||
use_parallel_residual: bool = False
|
||||
qk_layernorm: bool = False
|
||||
|
||||
|
||||
class LayerNormPerHead(nn.Module):
|
||||
|
||||
def __init__(self, head_dim, num_heads, eps):
|
||||
super().__init__()
|
||||
self.norms = [
|
||||
nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
|
||||
]
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
w = mx.stack([n.weight for n in self.norms])
|
||||
return w * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
bias=config.use_qkv_bias,
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.num_key_value_heads * self.head_dim,
|
||||
bias=config.use_qkv_bias,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
traditional=False,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
|
||||
self.qk_layernorm = config.qk_layernorm
|
||||
if self.qk_layernorm:
|
||||
self.q_layernorm = LayerNormPerHead(
|
||||
self.head_dim, self.num_heads, eps=config.layer_norm_eps
|
||||
)
|
||||
self.k_layernorm = LayerNormPerHead(
|
||||
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Extract some shapes
|
||||
B, L, D = queries.shape
|
||||
|
||||
queries = queries.reshape(B, L, self.num_heads, -1)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1)
|
||||
if self.qk_layernorm:
|
||||
queries = self.q_layernorm(queries)
|
||||
keys = self.k_layernorm(keys)
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
# Add RoPE to the queries and keys and combine them with the cache
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
queries = queries.astype(mx.float32)
|
||||
keys = keys.astype(mx.float32)
|
||||
|
||||
# Finally perform the attention computation
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=scale, mask=mask
|
||||
).astype(values.dtype)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config=config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
if not self.use_parallel_residual:
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
h = self.input_layernorm(x)
|
||||
r = self.self_attn(h, mask, cache)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
out = x + r + self.mlp(h)
|
||||
else:
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class StableLM(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, mask, cache=c)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.model = StableLM(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.args = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,175 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
norm_epsilon: float = 1e-5
|
||||
vocab_size: int = 49152
|
||||
rope_theta: float = 100000
|
||||
tie_word_embeddings: bool = True
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
|
||||
head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
|
||||
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.c_proj(nn.gelu(self.c_fc(x)))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.hidden_size = args.hidden_size
|
||||
self.n_heads = args.num_attention_heads
|
||||
|
||||
self.self_attn = Attention(args)
|
||||
self.mlp = MLP(args.hidden_size, args.intermediate_size)
|
||||
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
args.hidden_size, eps=args.norm_epsilon
|
||||
)
|
||||
self.args = args
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Starcoder2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = [
|
||||
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
|
||||
]
|
||||
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Starcoder2Model(args)
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,79 +0,0 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class SuScaledRotaryEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
traditional: bool = False,
|
||||
base: float = 10000.0,
|
||||
scale: float = 1.0,
|
||||
max_position_embeddings: int = 131072,
|
||||
original_max_position_embeddings: int = 4096,
|
||||
short_factor: Union[List[float], float] = 1.0,
|
||||
long_factor: Union[List[float], float] = 1.0,
|
||||
):
|
||||
"""
|
||||
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
|
||||
|
||||
Args:
|
||||
dims (int): The feature dimensions to be rotated.
|
||||
traditional (bool, optional): Unused. Default: ``False``.
|
||||
base (int, optional): Base for the exponential scaling.
|
||||
scale (float, optional): The scale used to scale the positions.
|
||||
Default: ``1.0``.
|
||||
max_position_embeddings (int, optional): The maximum sequence
|
||||
length that this model was trained with. This is used to determine
|
||||
the size of the original RoPE embeddings when using long scaling.
|
||||
Default: ``131072``.
|
||||
original_max_position_embeddings (int, optional): The maximum
|
||||
sequence length that this model was trained with. This is used to
|
||||
determine the size of the original RoPE embeddings when using long
|
||||
scaling. Default: ``4096``.
|
||||
short_factor (float or list[float], optional): List of scaling
|
||||
factors for sequences of length lesser than
|
||||
``original_max_position_embeddings``. Default: ``1.0``.
|
||||
long_factor (float or list[float], optional): List of scaling
|
||||
factors for sequences of length greater than
|
||||
``original_max_position_embeddings``. Default: ``1.0``.
|
||||
"""
|
||||
self.inv_freq_short = 1.0 / (
|
||||
mx.array(short_factor, dtype=mx.float32)
|
||||
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
||||
)
|
||||
self.inv_freq_long = 1.0 / (
|
||||
scale
|
||||
* mx.array(long_factor, dtype=mx.float32)
|
||||
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
||||
)
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.scaling_factor = math.sqrt(
|
||||
1
|
||||
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
||||
/ math.log(original_max_position_embeddings)
|
||||
)
|
||||
|
||||
def _get_cos_sin(self, offset, L):
|
||||
position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
|
||||
inv_freq = (
|
||||
self.inv_freq_long
|
||||
if (offset + L) > self.original_max_position_embeddings
|
||||
else self.inv_freq_short
|
||||
)
|
||||
freqs = position_ids[:, None] * inv_freq[None, :]
|
||||
emb = mx.concatenate([freqs, freqs], axis=-1)
|
||||
cos = mx.cos(emb) * self.scaling_factor
|
||||
sin = mx.sin(emb) * self.scaling_factor
|
||||
return cos, sin
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
def _rotate_half(_x):
|
||||
midpoint = _x.shape[-1] // 2
|
||||
x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
|
||||
return mx.concatenate([-x2, x1], axis=-1)
|
||||
|
||||
cos, sin = self._get_cos_sin(offset, x.shape[2])
|
||||
return (x * cos) + (_rotate_half(x) * sin)
|
||||
@@ -1,165 +0,0 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class QuantizedSwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
num_experts: int,
|
||||
bias: bool = True,
|
||||
group_size: int = 64,
|
||||
bits: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight, self.scales, self.biases = mx.quantize(
|
||||
mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
),
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
self.group_size = group_size
|
||||
self.bits = bits
|
||||
|
||||
# Freeze this model's parameters
|
||||
self.freeze()
|
||||
|
||||
def unfreeze(self, *args, **kwargs):
|
||||
"""Wrap unfreeze so that we unfreeze any layers we might contain but
|
||||
our parameters will remain frozen."""
|
||||
super().unfreeze(*args, **kwargs)
|
||||
self.freeze(recurse=False)
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.scales.shape[2] * self.group_size
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.gather_qmm(
|
||||
x,
|
||||
self["weight"],
|
||||
self["scales"],
|
||||
self["biases"],
|
||||
rhs_indices=indices,
|
||||
transpose=True,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
|
||||
class SwitchLinear(nn.Module):
|
||||
def __init__(
|
||||
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
scale = math.sqrt(1 / input_dims)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(num_experts, output_dims, input_dims),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((num_experts, output_dims))
|
||||
|
||||
@property
|
||||
def input_dims(self):
|
||||
return self.weight.shape[2]
|
||||
|
||||
@property
|
||||
def output_dims(self):
|
||||
return self.weight.shape[1]
|
||||
|
||||
@property
|
||||
def num_experts(self):
|
||||
return self.weight.shape[0]
|
||||
|
||||
def __call__(self, x, indices):
|
||||
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
|
||||
if "bias" in self:
|
||||
x = x + mx.expand_dims(self["bias"][indices], -2)
|
||||
return x
|
||||
|
||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
||||
num_experts, output_dims, input_dims = self.weight.shape
|
||||
ql = QuantizedSwitchLinear(
|
||||
input_dims, output_dims, num_experts, False, group_size, bits
|
||||
)
|
||||
ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
|
||||
if "bias" in self:
|
||||
ql.bias = self.bias
|
||||
return ql
|
||||
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.silu,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x_up = self.up_proj(x, indices)
|
||||
x_gate = self.gate_proj(x, indices)
|
||||
x = self.down_proj(self.activation(x_gate) * x_up, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
|
||||
|
||||
class SwitchMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
hidden_dims: int,
|
||||
num_experts: int,
|
||||
activation=nn.gelu_approx,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
|
||||
self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
|
||||
self.activation = activation
|
||||
|
||||
def __call__(self, x, indices) -> mx.array:
|
||||
x = mx.expand_dims(x, (-2, -3))
|
||||
|
||||
x = self.fc1(x, indices)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x, indices)
|
||||
|
||||
return x.squeeze(-2)
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
mlx>=0.14.1
|
||||
numpy
|
||||
transformers[sentencepiece]>=4.39.3
|
||||
protobuf
|
||||
pyyaml
|
||||
jinja2
|
||||
@@ -1,34 +0,0 @@
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
|
||||
"""
|
||||
Apply top-p (nucleus) sampling to logits.
|
||||
|
||||
Args:
|
||||
logits: The logits from the model's output.
|
||||
top_p: The cumulative probability threshold for top-p filtering.
|
||||
temperature: Temperature parameter for softmax distribution reshaping.
|
||||
Returns:
|
||||
token selected based on the top-p criterion.
|
||||
"""
|
||||
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
||||
probs = mx.softmax(logits / temperature, axis=-1)
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=-1)
|
||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
# select tokens with cumulative probs below threshold
|
||||
top_probs = mx.where(
|
||||
cumulative_probs > 1 - top_p,
|
||||
sorted_probs,
|
||||
mx.zeros_like(sorted_probs),
|
||||
)
|
||||
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||
token = sorted_indices.squeeze(0)[sorted_token]
|
||||
|
||||
return token
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user